# Instalación

In [1]:
# Install the numerical stack and the two tensor-train packages imported later
# in this notebook (ttml -> ttml.tensor_train.TensorTrain, t3f). These are
# IPython shell escapes, so this cell only runs inside Jupyter/Colab.
!pip install numpy scipy
!pip install ttml

!pip install t3f

import importlib.util
import os


def find_function_end(source, start):
    """Return the index just past the last line of the top-level function at ``start``.

    ``start`` must point at the ``def`` keyword of a function defined at
    column 0. The function is considered to end before the next non-blank,
    non-comment line that begins at column 0 (the next top-level ``def``,
    ``class``, decorator or statement). Blank lines and column-0 comments that
    trail the body are NOT included, so replacing ``source[start:end]`` keeps
    the spacing before the next definition and any comment attached to it.
    """
    end = position = start
    for line_number, line in enumerate(source[start:].splitlines(keepends=True)):
        position += len(line)
        stripped = line.strip()
        indented = bool(stripped) and line[0].isspace()
        if line_number > 0 and stripped and not indented and not stripped.startswith('#'):
            break  # reached the next top-level statement
        if line_number == 0 or indented:
            # The header line and every indented non-blank line belong to the body.
            end = position
    return end


# Locate the installed t3f package without importing it: find_spec() on a
# top-level package only resolves its location, so the module file can still
# be patched before t3f is first imported in this session (a module that is
# already imported keeps the old code until the runtime restarts).
# The previous hard-coded path only matched a Python 3.11 dist-packages layout;
# it is kept purely as a fallback.
_t3f_spec = importlib.util.find_spec('t3f')
if _t3f_spec is not None and _t3f_spec.submodule_search_locations:
    init_file_path = os.path.join(list(_t3f_spec.submodule_search_locations)[0], 'initializers.py')
else:
    init_file_path = '/usr/local/lib/python3.11/dist-packages/t3f/initializers.py'

# Replacement source for t3f.initializers.random_tensor (casts ranks with the
# builtin `int` dtype).
new_random_tensor = '''def random_tensor(shape, tt_rank=2, mean=0., stddev=1., dtype=tf.float32,
                  name='t3f_random_tensor'):
    """Generate a random TT-tensor of the given shape with given mean and stddev.

    Entries of the generated tensor (in the full format) will be iid and satisfy
    E[x_{i1i2..id}] = mean, Var[x_{i1i2..id}] = stddev^2, but the distribution is
    in fact not Gaussian (but is close for large tensors).

    In the current implementation only mean 0 is supported. To get
    a random_tensor with specified mean but tt_rank greater by 1 you can
    call:
        x = t3f.random_tensor(shape, tt_rank, stddev=stddev)
        x = mean * t3f.ones_like(x) + x

    Args:
        shape: array representing the shape of the future tensor.
        tt_rank: a number or a (d+1)-element array with the desired ranks.
        mean: a number, the desired mean for the distribution of entries.
        stddev: a number, the desired standard deviation for the distribution of
            entries.
        dtype: [tf.float32] dtype of the resulting tensor.
        name: string, name of the Op.

    Returns:
        TensorTrain containing a TT-tensor
    """
    shape = np.array(shape)
    tt_rank = np.array(tt_rank)
    _validate_input_parameters(is_tensor=True, shape=shape, tt_rank=tt_rank)

    num_dims = shape.size
    if tt_rank.size == 1:
        tt_rank = tt_rank * np.ones(num_dims - 1, dtype=int)
        tt_rank = np.insert(tt_rank, 0, 1)
        tt_rank = np.append(tt_rank, 1)

    tt_rank = tt_rank.astype(int)

    # Empirically entries of a TT tensor with cores initialized from N(0, 1)
    # will have variances np.prod(tt_rank) and mean 0.
    # We scale each TT-core to obtain the desired stddev
    cr_exponent = -1.0 / (2 * num_dims)
    var = np.prod(tt_rank ** cr_exponent)
    core_stddev = stddev ** (1.0 / num_dims) * var
    with tf.name_scope(name):
        tt = tensor_with_random_cores(shape, tt_rank=tt_rank, stddev=core_stddev,
                                      dtype=dtype)

    if np.abs(mean) < 1e-8:
        return tt
    else:
        raise NotImplementedError('non-zero mean is not supported yet')
'''

if os.path.exists(init_file_path):
    with open(init_file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # 1. Add an import right after "from t3f import shapes" (only once, so the
    #    cell can be re-run). NOTE(review): presumably other code in
    #    initializers.py relies on `integer`; confirm against the t3f source.
    import_line = 'from t3f import shapes'
    new_import = 'from numbers import Integral as integer'
    if new_import not in content:
        content = content.replace(import_line, f"{import_line}\n{new_import}", 1)

    # 2. Replace the body of random_tensor, from its `def` line up to (but not
    #    including) the blank lines before the next top-level statement.
    start_marker = 'def random_tensor('
    try:
        start_idx = content.index(start_marker)
        end_idx = find_function_end(content, start_idx)

        new_content = content[:start_idx] + new_random_tensor + content[end_idx:]

        with open(init_file_path, 'w', encoding='utf-8') as f:
            f.write(new_content)

        print(f"El archivo {init_file_path} ha sido modificado exitosamente!")
    except ValueError as e:
        print(f"Error al encontrar marcadores en el archivo: {e}")
else:
    print(f"El archivo {init_file_path} no existe o no es accesible.")

import tensorflow as tf

# Configure GPUs before any TensorFlow operation runs: memory growth can only be
# changed before the devices are initialised (TensorFlow raises RuntimeError
# otherwise, which is reported below instead of aborting the cell).
physical_devices = tf.config.list_physical_devices('GPU')
if len(physical_devices) > 0:
    try:
        # Enable growth on every visible GPU, not only the first one:
        # TensorFlow requires the memory-growth setting to match across GPUs.
        for gpu in physical_devices:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("GPU detectada y configurada para TensorFlow")
    except RuntimeError as e:
        print(f"Advertencia: No se pudo configurar el crecimiento de memoria en la GPU: {e}")
else:
    print("No se detectó GPU, ejecutando en CPU")

Collecting ttml
  Downloading ttml-1.0-py3-none-any.whl.metadata (3.0 kB)
Collecting autoray (from ttml)
  Downloading autoray-0.7.1-py3-none-any.whl.metadata (5.8 kB)
Downloading ttml-1.0-py3-none-any.whl (97 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m97.1/97.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading autoray-0.7.1-py3-none-any.whl (930 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m930.8/930.8 kB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: autoray, ttml
Successfully installed autoray-0.7.1 ttml-1.0
Collecting t3f
  Downloading t3f-1.2.0.tar.gz (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.0/58.0 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: t3f
  Building wheel for t3f (setup.py) ... [?25l[?25hdone
  Created wheel for t3f: filename=t3f-1.2.0-py3-none-any.w

# Funciones Auxiliares

In [6]:
import t3f
from ttml.tensor_train import TensorTrain
import numpy as np
import torch
import matplotlib.pyplot as plt

# --- Helper function for error calculation ---
def calculate_relative_l2_error(y_true, y_pred):
    """Return the relative L2 error ||y_pred - y_true|| / ||y_true||.

    Non-ndarray inputs are converted with ``np.asarray``. ``np.inf`` is
    returned when either input is empty or the shapes differ. If ||y_true||
    is below 1e-12 the ratio is undefined, so the absolute norm of ``y_pred``
    is returned instead (or 0.0 when that norm is also below 1e-12).
    """
    truth = y_true if isinstance(y_true, np.ndarray) else np.asarray(y_true)
    pred = y_pred if isinstance(y_pred, np.ndarray) else np.asarray(y_pred)

    incomparable = truth.size == 0 or pred.size == 0 or truth.shape != pred.shape
    if incomparable:
        return np.inf

    truth_norm = np.linalg.norm(truth)
    if truth_norm < 1e-12:
        # Reference is (numerically) zero: fall back to the absolute error.
        pred_norm = np.linalg.norm(pred)
        return pred_norm if pred_norm > 1e-12 else 0.0

    return np.linalg.norm(pred - truth) / truth_norm

def plot_loss_history(loss_hist, val_loss_hist, loss_metric_name='Relative L2 Error', title='Optimization Loss History'):
    """
    Plot training and validation loss histories against the iteration index.

    The first entry of each history is treated as the initial state and drawn
    at iteration -1; entry k is drawn at iteration k - 1. The y-axis is
    logarithmic so convergence over several orders of magnitude is visible.

    Args:
        loss_hist (sequence of float): Training loss per iteration.
        val_loss_hist (sequence of float or None): Validation loss per
            iteration. It may differ in length from ``loss_hist``, or be
            empty/None, in which case only the training curve is drawn.
        loss_metric_name (str, optional): Metric name used in the labels.
            Defaults to 'Relative L2 Error'.
        title (str, optional): Plot title. Defaults to 'Optimization Loss History'.
    """
    plt.figure(figsize=(12, 6))

    curves = (
        (loss_hist, f'Train Loss ({loss_metric_name})', 'blue'),
        (val_loss_hist, f'Validation Loss ({loss_metric_name})', 'red'),
    )
    for history, label, color in curves:
        # Skip missing histories and build the x-axis per curve, so histories
        # of different lengths no longer make plt.plot raise a shape mismatch.
        if history is None or len(history) == 0:
            continue
        iterations = np.arange(len(history)) - 1  # first point is iteration -1
        plt.plot(iterations, history, label=label, color=color, alpha=0.8)

    plt.title(title)
    plt.xlabel('Iteration')
    plt.ylabel(loss_metric_name)
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.yscale('log')
    plt.legend()
    plt.tight_layout()
    plt.show()

def process_tt_approximation_and_convert(
    best_tt_approx: TensorTrain,
    increase_amount: int = 0
) -> t3f.TensorTrain | None:
    """
    Optionally enlarge the TT rank of a ttml.TensorTrain, then convert it to t3f.

    Progress messages (in Spanish) are printed at each step.

    Args:
        best_tt_approx: ttml.TensorTrain to convert. Anything else is rejected
            with a printed error and ``None`` is returned.
        increase_amount (int, optional): Passed to ``increase_rank`` when it is
            greater than 0. Defaults to 0 (rank left unchanged). The return
            value of ``increase_rank`` is not used, so the rank change is
            presumably applied to the caller's object in place.

    Returns:
        t3f.TensorTrain | None: Whatever ``convertir_ttml_a_t3f`` returns;
        ``None`` signals a failed conversion.
    """
    # Guard clause: reject anything that is not a ttml TensorTrain.
    if not isinstance(best_tt_approx, TensorTrain):
        print("\nError: 'best_tt_approx' is not a valid ttml.TensorTrain object. Conversion skipped.")
        return None

    print(f"Rango TT inicial: {best_tt_approx.tt_rank}")

    should_grow = increase_amount > 0
    if should_grow:
        best_tt_approx.increase_rank(increase_amount)
        print(f"Rango TT después de aumentar en {increase_amount} unidades: {best_tt_approx.tt_rank}")
    else:
        print("No se solicitó aumento de rango (increase_amount es 0 o menor).")

    converted = convertir_ttml_a_t3f(best_tt_approx)

    status_message = (
        "\nLa conversión de best_tt_approx a t3f falló."
        if converted is None
        else "\nEl objeto best_tt_approx ha sido convertido exitosamente a t3f en la variable best_tt_approx_t3f."
    )
    print(status_message)
    print("\n--- Fin del proceso de Conversión ---")

    return converted

def _as_1d_float64(values, error_message):
    """Convert ``values`` to a float64 array, flattening column-shaped input.

    Input with more than one dimension is accepted only when its second axis
    has length 1 (e.g. an ``(n, 1)`` column); it is then flattened. Any other
    multi-dimensional input raises ``ValueError(error_message)``.
    """
    values = np.asarray(values)
    if values.ndim > 1:
        if values.shape[1] != 1:
            raise ValueError(error_message)
        return values.flatten().astype(np.float64)
    return values.astype(np.float64)


def augment_training_set(
    Omega: tf.Tensor | np.ndarray,
    A_Omega: tf.Tensor | np.ndarray,
    sizeOmegaExtra: int,
    number_nodes: int,
    d: int,
    custom_function: callable,
    seed: int | None = None  # Seed for the extra-point RNG (reproducibility)
) -> tuple[tf.Tensor, tf.Tensor]:
    """
    Augment a training set (Omega, A_Omega) with extra random samples.

    When ``sizeOmegaExtra > 0``, that many multi-indices are drawn with
    ``make_omega_set`` (seeded by ``seed``), mapped through
    ``map_to_chebyshev_nodes`` and evaluated with ``custom_function``; the new
    pairs are appended after the existing ones. Extra indices are not checked
    against the existing Omega, so duplicates are possible.

    Args:
        Omega: Initial multi-indices. tf.Tensor, np.ndarray or anything
            ``np.asarray`` accepts. A 1D input is reshaped to ``(-1, d)``.
        A_Omega: Initial values for Omega, 1D or a single-column 2D array.
        sizeOmegaExtra (int): Number of extra points to add (0 adds none).
        number_nodes (int): Nodes per dimension, forwarded to
            make_omega_set and map_to_chebyshev_nodes.
        d (int): Dimensionality of the multi-indices.
        custom_function (callable): Maps Chebyshev nodes to values; may return
            a 1D array/list or a single-column 2D array.
        seed (int, optional): Seed for np.random.RandomState; None means a
            non-reproducible seed.

    Returns:
        tuple[tf.Tensor, tf.Tensor]: Augmented indices (tf.int32) and
        augmented values (tf.float64, 1D).

    Raises:
        ValueError: If A_Omega or the output of custom_function is neither 1D
            nor a 2D array with a single column.
    """
    rng = np.random.RandomState(seed)

    # tf tensors go through .numpy(); anything else (ndarray, nested lists)
    # goes through np.asarray so `.ndim` / `.reshape` below always work.
    Omega_np = Omega.numpy() if tf.is_tensor(Omega) else np.asarray(Omega)
    A_Omega_raw = A_Omega.numpy() if tf.is_tensor(A_Omega) else A_Omega
    A_Omega_np = _as_1d_float64(
        A_Omega_raw,
        "A_Omega (initial evaluations) must be 1D or a 2D array with one column.",
    )

    extra_indices = np.empty((0, d), dtype=int)
    extra_evaluations = np.empty((0,), dtype=np.float64)

    if sizeOmegaExtra > 0:
        print(f"Attempting to add {sizeOmegaExtra} extra points to the training set...")
        extra_indices = make_omega_set(number_nodes, sizeOmegaExtra, d, random_state=rng)
        extra_nodes = map_to_chebyshev_nodes(extra_indices, number_nodes)
        extra_evaluations = _as_1d_float64(
            custom_function(extra_nodes),
            "custom_function must return a 1D array or a 2D array with one column for extra_evaluations.",
        )
        print(f"Successfully generated {len(extra_indices)} extra points.")
    else:
        print("sizeOmegaExtra is 0, no extra points will be added to the training set.")

    # vstack needs a 2D Omega, including the empty case (shape (0,) -> (0, d)).
    if Omega_np.ndim == 1:
        Omega_np = Omega_np.reshape(-1, d)

    Omega_indices_augmented_np = np.vstack((Omega_np, extra_indices))
    A_Omega_augmented_np = np.hstack((A_Omega_np, extra_evaluations))

    Omega_tf = tf.constant(Omega_indices_augmented_np, dtype=tf.int32)
    A_Omega_tf = tf.constant(A_Omega_augmented_np, dtype=tf.float64)

    print(f"Training set augmented. New size of Omega: {len(Omega_tf)}.")

    return Omega_tf, A_Omega_tf

def make_omega_set_sobol(n_nodes, size, d, seed=None):
    """
    Generate quasi-random multi-indices by discretizing a scrambled Sobol sequence.

    Parameters:
    - n_nodes (int or numpy integer): Upper bound for each index; entries lie
      in 0..n_nodes - 1.
    - size (int or numpy integer): Number of multi-indices to generate.
    - d (int or numpy integer): Number of dimensions, between 1 and
      torch's SobolEngine.MAXDIM.
    - seed (Optional[int]): Seed of the Sobol scrambling, for reproducibility.

    Returns:
    - np.ndarray: Integer array of shape (size, d). Duplicate rows are
      possible, because distinct Sobol points can fall into the same cell.

    Raises:
    - ValueError: On non-integer or out-of-range arguments, or when size
      exceeds n_nodes ** d.
    """
    # Accept NumPy integer scalars (e.g. values taken from array shapes) as
    # well as Python ints; normalise to int for torch and for shape tuples.
    integer_types = (int, np.integer)
    if not isinstance(n_nodes, integer_types) or n_nodes < 1:
        raise ValueError("n_nodes must be a positive integer")
    if not isinstance(size, integer_types) or size < 0:
        raise ValueError("size must be a non-negative integer")
    if not isinstance(d, integer_types):
        raise ValueError("d must be an integer")
    n_nodes, size, d = int(n_nodes), int(size), int(d)

    # Use the engine's own limit instead of a hard-coded 1111 (newer torch
    # releases support far more dimensions).
    max_dim = getattr(torch.quasirandom.SobolEngine, "MAXDIM", 1111)
    if not (1 <= d <= max_dim):
        raise ValueError(f"SobolEngine dimension d must be between 1 and {max_dim}, inclusive. Got {d}")

    if size == 0:
        return np.empty((0, d), dtype=int)

    # Kept for parity with make_omega_set: more points than cells would
    # necessarily contain duplicates.
    total_combinations = n_nodes ** d
    if size > total_combinations:
        raise ValueError(f"Requested size ({size}) exceeds total possible combinations ({total_combinations})")

    engine = torch.quasirandom.SobolEngine(dimension=d, scramble=True, seed=seed)
    samples_unit_cube = engine.draw(n=size)  # (size, d), values in [0, 1)

    # Scale to [0, n_nodes) and take the floor to obtain integer cells.
    omega_torch = torch.floor(samples_unit_cube * float(n_nodes))
    # Defensive: float32 rounding of the product for very large n_nodes must
    # never yield an out-of-range index.
    omega_torch = torch.clamp(omega_torch, 0, n_nodes - 1)

    return omega_torch.long().numpy()

def generate_disjoint_omega_c_sobol(Omega, sizeOmega_C, number_nodes, seed=None):
    """
    Generate a set Omega_C of unique multi-indices disjoint from Omega using Sobol points.

    Candidates are drawn one at a time from a scrambled Sobol sequence, scaled
    to [0, number_nodes) and floored. Candidates already in Omega (or already
    chosen) are rejected, with at most 10 * sizeOmega_C draws in total.

    Parameters:
    - Omega (np.ndarray, tf.Tensor or torch.Tensor): Existing (N, d) indices.
    - sizeOmega_C (int or numpy integer): Number of new indices to produce.
    - number_nodes (int or numpy integer): Nodes per dimension.
    - seed (Optional[int]): Seed of the Sobol scrambling, for reproducibility.

    Returns:
    - np.ndarray: Integer array of shape (sizeOmega_C, d); (0, 0) when Omega
      has zero columns and sizeOmega_C is 0.

    Raises:
    - TypeError: If Omega is neither an ndarray nor an object with ``.numpy()``.
    - ValueError: On invalid sizes, a non-2D Omega, an unsupported dimension,
      too few free cells, or when the attempt budget is exhausted.
    """
    # ndarrays are used directly; tensors exposing .numpy() (eager tf.Tensor,
    # torch.Tensor) are converted. This replaces the former
    # `TF_AVAILABLE and isinstance(Omega, tf.Tensor)` test, which raised
    # NameError whenever the TF_AVAILABLE flag had not been defined.
    if isinstance(Omega, np.ndarray):
        Omega_np = Omega
    elif callable(getattr(Omega, "numpy", None)):
        Omega_np = np.asarray(Omega.numpy())
    else:
        raise TypeError("Omega must be a NumPy array or TensorFlow tensor.")

    integer_types = (int, np.integer)
    if not isinstance(sizeOmega_C, integer_types) or sizeOmega_C < 0:
        raise ValueError("sizeOmega_C must be a non-negative integer")
    if not isinstance(number_nodes, integer_types) or number_nodes < 1:
        raise ValueError("number_nodes must be a positive integer")
    sizeOmega_C = int(sizeOmega_C)
    number_nodes = int(number_nodes)

    # d is inferred from Omega.shape[1], so Omega must be 2D (an empty set
    # should be passed as np.empty((0, d))). Any other rank is rejected here
    # instead of failing later with IndexError/TypeError.
    if Omega_np.ndim != 2:
        if Omega_np.shape == (0,) and sizeOmega_C > 0:
            raise ValueError("Cannot infer dimension d from an empty Omega with shape (0,).")
        raise ValueError("Omega has an unexpected shape. Expected 2D array.")

    d = Omega_np.shape[1]

    max_dim = getattr(torch.quasirandom.SobolEngine, "MAXDIM", 1111)
    if not (1 <= d <= max_dim):
        # Sobol cannot handle d == 0; it is only acceptable when nothing is requested.
        if d == 0 and sizeOmega_C == 0:
            return np.empty((0, 0), dtype=int)
        raise ValueError(f"SobolEngine dimension d must be between 1 and {max_dim}, inclusive. Got {d} from Omega.shape[1]")

    if sizeOmega_C == 0:
        return np.empty((0, d), dtype=int)

    total_points = number_nodes ** d

    # Set of taken cells for O(1) membership tests.
    all_indices = {tuple(idx) for idx in Omega_np}

    if len(all_indices) + sizeOmega_C > total_points:
        raise ValueError(f"Cannot generate {sizeOmega_C} unique indices. "
                         f"Available points: {total_points - len(all_indices)}."
                         f" Requested: {sizeOmega_C}.")

    max_attempts = 10 * sizeOmega_C  # attempt budget: one Sobol draw per attempt

    Omega_C_list = []
    attempts = 0

    engine = torch.quasirandom.SobolEngine(dimension=d, scramble=True, seed=seed)

    while len(Omega_C_list) < sizeOmega_C and attempts < max_attempts:
        sample_unit_cube = engine.draw(n=1)  # shape (1, d), values in [0, 1)
        candidate_torch = torch.floor(sample_unit_cube * float(number_nodes))
        # Defensive clamp against float32 rounding for very large number_nodes.
        candidate_torch = torch.clamp(candidate_torch, 0, number_nodes - 1)
        candidate = candidate_torch.long().numpy()[0]  # 1D array of length d

        key = tuple(candidate)
        if key not in all_indices:
            Omega_C_list.append(candidate)
            all_indices.add(key)

        attempts += 1

    if len(Omega_C_list) < sizeOmega_C:
        raise ValueError("Could not generate enough unique indices for Omega_C within max_attempts.")

    # Here len(Omega_C_list) == sizeOmega_C > 0, so the result is (sizeOmega_C, d).
    return np.array(Omega_C_list, dtype=int)

def format_value_dynamically(value, fixed_digits=4, scientific_digits=4):
    """
    Format a number, switching to scientific notation only when needed.

    A non-zero value whose magnitude is below half a unit in the last
    fixed-point place (0.5 * 10**-fixed_digits) would print as all zeros, so it
    is rendered with ``scientific_digits`` mantissa decimals instead
    (e.g. "1.2345e-08"). Everything else, including exact zero, uses
    fixed-point notation with ``fixed_digits`` decimals (e.g. "0.0001").

    Args:
        value (float): Number to format.
        fixed_digits (int): Decimals for fixed-point notation.
        scientific_digits (int): Mantissa decimals for scientific notation.

    Returns:
        str: The formatted number.
    """
    cutoff = 0.5 * (10**-fixed_digits)
    magnitude = abs(value)
    use_scientific = magnitude > 0 and magnitude < cutoff
    spec = f".{scientific_digits}e" if use_scientific else f".{fixed_digits}f"
    return format(value, spec)

# Evaluates custom_function on a batch of index vectors AND records every
# (index, value) pair in the global collected_indices / collected_evaluations lists.
def physical_point_index_fun(indices_array): # Takes only the index array.
    """
    Evaluate custom_function on a batch of multi-indices and record the pairs.

    The indices are normalised to a 2D (num_samples, d) array, mapped to
    physical points with map_to_chebyshev_nodes(indices, MODAL_SIZE) and passed
    to custom_function. On success every (index row, value) pair is appended
    to the global lists collected_indices / collected_evaluations and the
    values are returned.

    Accepted index shapes:
    - 1D array: treated as ONE sample (reshaped to (1, -1)).
    - 2D array: used as-is, one row per sample.
    - 3D array with shape[1] == 1: reduced to (num_samples, d) via [:, 0, :].
    - anything else: reshaped to (-1, d) when its size is a positive multiple
      of the global d.

    Args:
        indices_array (np.ndarray): Integer multi-indices; presumably in
            0..MODAL_SIZE-1 per dimension (not validated here).

    Returns:
        np.ndarray: custom_function's values flattened to 1D, one per sample.
        On any failure (non-ndarray input, unsupported shape, empty batch,
        mapping error, custom_function error, or a value/sample count
        mismatch) a message is printed, nothing is collected, and an EMPTY
        1D array is returned instead, so callers must handle that case.

    Globals read: d, MODAL_SIZE. Globals mutated in place: collected_indices,
    collected_evaluations.
    """
    # Global lists extended in place by this function.
    global collected_indices, collected_evaluations
    global d # Read only: used when reshaping unexpected shapes to (-1, d)
    global MODAL_SIZE # Read only: nodes per dimension for map_to_chebyshev_nodes

    # --- Input normalisation ---
    # Goal: indices_input_processed is a 2D array (num_samples, d).
    indices_input_processed = None
    num_samples_in_batch = 0

    if isinstance(indices_array, np.ndarray):
        if indices_array.ndim == 1:
            # NOTE(review): a 1D array is always one point; with d == 1 a 1D
            # batch of several indices would be misread as a single point —
            # confirm callers never pass that shape.
            indices_input_processed = indices_array.reshape(1, -1)
            num_samples_in_batch = 1
        elif indices_array.ndim == 2:
            indices_input_processed = indices_array
            num_samples_in_batch = indices_array.shape[0]
        elif indices_array.ndim == 3 and indices_array.shape[1] == 1:
            # Flatten (num_samples, 1, d) to (num_samples, d)
            indices_input_processed = indices_array[:, 0, :]
            num_samples_in_batch = indices_array.shape[0]
        # Other ranks (0-D, 3-D without a singleton axis 1, 4-D+) fall through here.
        else:
            # Last resort: reshape to (-1, d) when the element count divides evenly by d.
            total_elements = indices_array.size
            if total_elements > 0 and d > 0 and total_elements % d == 0:
                try:
                    temp_reshaped = indices_array.reshape(-1, d)
                    indices_input_processed = temp_reshaped
                    num_samples_in_batch = indices_input_processed.shape[0]
                    # print(f"Debug (physical_point): Reshaped unexpected shape {indices_array.shape} to {indices_input_processed.shape}")
                except Exception as e:
                     print(f"Error reshaping indices (unexpected shape {indices_array.shape}) in physical_point_index_fun: {e}. Skipping batch processing and collection.")
                     # Indices cannot be interpreted: nothing is collected.
                     return np.array([]) # Return empty 1D array.

            else:
                 print(f"Error: Cannot process indices with unexpected shape {indices_array.shape} and size {total_elements} in physical_point_index_fun. Skipping batch processing and collection.")
                 return np.array([]) # Return empty 1D array


    else:
         # Only NumPy arrays are supported (lists, tensors, etc. are rejected).
         print(f"Error: Unexpected indices type: {type(indices_array)} in physical_point_index_fun. Skipping batch processing and collection.")
         return np.array([]) # Return empty 1D array


    if num_samples_in_batch <= 0:
         # e.g. a 2D array with zero rows: nothing to evaluate or collect.
         # print(f"Debug: physical_point_index_fun received batch with 0 samples.")
         return np.array([]) # Return empty 1D array


    # --- Mapping to physical points ---
    # map_to_chebyshev_nodes (defined elsewhere) receives the (num_samples, d)
    # indices and the number of nodes per dimension, MODAL_SIZE.
    try:
        physical_points = map_to_chebyshev_nodes(indices_input_processed, MODAL_SIZE)
    except ValueError as e:
         print(f"Error during map_to_chebyshev_nodes in physical_point_index_fun: {e}. Skipping batch processing and collection.")
         return np.array([]) # Return empty 1D array
    except Exception as e:
         print(f"Unexpected error during map_to_chebyshev_nodes in physical_point_index_fun: {e}. Skipping batch processing and collection.")
         return np.array([]) # Return empty 1D array


    # --- Call the target function ---
    # custom_function is defined elsewhere; it presumably takes (num_samples, d)
    # points and returns num_samples values (and may update its own evaluation
    # counter, e.g. total_vector_evaluations — confirm in its definition).
    try:
        values = custom_function(physical_points)
        # Normalise the result to a 1D NumPy array of length num_samples.
        values = np.asarray(values) # Ensure it's a numpy array first
        if values.ndim > 1: # Flatten if it's not 1D, e.g., (num_samples, 1)
             values = values.flatten()

        if len(values) != num_samples_in_batch:
             # custom_function returned a different number of values than
             # there are input points, so indices and values cannot be paired.
             print(f"CRITICAL Mismatch: custom_function returned {len(values)} values for {num_samples_in_batch} input points. Expected {num_samples_in_batch}. Skipping batch collection for this batch.")
             # Nothing is collected; an empty array is returned.
             return np.array([])

    except Exception as e:
         # Also catches a 0-D result (len() of a 0-D array raises TypeError).
         print(f"Error during custom_function call in physical_point_index_fun: {e}. Skipping batch collection for this batch.")
         # Without values there is nothing to collect.
         return np.array([]) # Return empty 1D array


    # --- Collection ---
    # indices_input_processed is (num_samples, d) and values is (num_samples,),
    # both of length num_samples_in_batch if execution reached this point.

    # One list per index row; elements are NumPy integer scalars (row.astype(int)).
    indices_to_collect_list = [list(row.astype(int)) for row in indices_input_processed]

    # Values as a list of Python scalars.
    values_to_collect_list = values.tolist() # 'values' is already a 1D numpy array here

    # Defensive consistency check before touching the global lists (should
    # always pass given the length check above).
    if len(indices_to_collect_list) != len(values_to_collect_list):
         # Would indicate a logic error above; the batch is not collected.
         print(f"CRITICAL Logic Error: Mismatch before global collection. Indices list length {len(indices_to_collect_list)}, Values list length {len(values_to_collect_list)}. Skipping batch collection.")
         # Global lists are left untouched in this case.
    else:
        # Normal case: append this batch's pairs to the global lists.
        collected_indices.extend(indices_to_collect_list)
        collected_evaluations.extend(values_to_collect_list)
        # print(f"Debug: Successfully collected {len(indices_to_collect_list)} pairs from batch in physical_point_index_fun.")


    # Return one value per input sample to the caller (even if the final
    # consistency check above skipped collection).
    return values


# Modificada create_tt_initial y get_tt_shape - ya se hicieron en pasos anteriores

# Helper function to get the shape of a ttml.TensorTrain
def get_tt_shape(tt):
    """
    Infer the mode sizes of a TT object from the middle axis of its cores.

    Args:
        tt: Object exposing ``cores``, a list or tuple of 3D NumPy arrays with
            shape (r_{k-1}, n_k, r_k).

    Returns:
        tuple[int, ...] | None: (n_1, ..., n_d); ``()`` for an empty core list;
        ``None`` when ``tt`` is None, has no list/tuple ``cores``, or any core
        is not a 3D ndarray (a warning is printed in that last case).

    If a global ``d`` is defined and differs from the number of cores, a
    warning is printed (the shape is still returned).
    """
    cores = getattr(tt, 'cores', None)
    if not isinstance(cores, (list, tuple)):
        return None

    if len(cores) == 0:
        return tuple()

    shape = []
    for i, core in enumerate(cores):
        if not isinstance(core, np.ndarray) or core.ndim != 3:
            print(f"Warning: Core {i} is not a 3D numpy array. Shape cannot be determined.")
            return None
        shape.append(core.shape[1])  # mode size n_k

    # Compare against the notebook-level `d` only when it exists; reading an
    # undefined global used to raise NameError here.
    expected_d = globals().get('d')
    if expected_d is not None and len(shape) != expected_d:
        print(f"Warning: Inferred shape dimension ({len(shape)}) does not match global dimension d ({expected_d}).")

    return tuple(shape)

def make_omega_set(n_nodes, size, d, random_state: np.random.RandomState):
    """
    Draw random multi-indices with d components, each in 0..n_nodes - 1.

    Entries are drawn independently (with replacement), so duplicate rows are
    possible.

    Parameters:
    - n_nodes (int): Number of admissible values per component.
    - size (int): Number of multi-indices (rows) to draw.
    - d (int): Number of components per multi-index.
    - random_state (np.random.RandomState): Generator used for the draw.

    Returns:
    - np.ndarray: Integer array of shape (size, d).

    Raises:
    - ValueError: If n_nodes < 1, size < 0, or size > n_nodes ** d.
    """
    if n_nodes < 1:
        raise ValueError("n_nodes must be positive")
    if size < 0:
        raise ValueError("size must be non-negative")
    if size > (n_combinations := n_nodes ** d):
        raise ValueError(f"Requested size ({size}) exceeds total possible combinations ({n_combinations})")

    # randint's upper bound is exclusive, so every entry lies in [0, n_nodes - 1].
    return random_state.randint(0, n_nodes, size=(size, d))

# Additional helper function needed for adaptive sampling
def generate_disjoint_omega_c(Omega, sizeOmega_C, number_nodes, random_state: np.random.RandomState):
    """
    Generate a disjoint Omega_C set that does not overlap with Omega, using a provided random_state.

    Candidates are drawn uniformly with ``random_state.randint`` and rejected
    when they already appear in Omega or among the points accepted so far
    (rejection sampling, at most 100 attempts per requested point).

    Parameters:
    - Omega: Existing indices (training set) of shape (m, d). Objects exposing a
      ``.numpy()`` method (e.g. eager tf.Tensor / tf.Variable) are converted
      first; any other array-like goes through ``np.asarray``.
    - sizeOmega_C: Size of the new Omega_C set.
    - number_nodes: Number of nodes per dimension (indices lie in [0, number_nodes - 1]).
    - random_state (np.random.RandomState): The random state object to use for reproducibility.

    Returns:
    - np.array: Integer array of shape (sizeOmega_C, d). Its rows are unique and
      none appears in Omega. When Omega is empty with no inferable d and
      sizeOmega_C == 0, an empty (0, 0) array is returned.

    Raises:
    - ValueError: If d cannot be inferred, if more points are requested than
      exist outside Omega, or if rejection sampling runs out of attempts.
    """
    # Duck-type tensor conversion so this works without referencing TensorFlow.
    if not isinstance(Omega, np.ndarray) and hasattr(Omega, 'numpy'):
        Omega = Omega.numpy()
    Omega = np.asarray(Omega)

    # Handle the case where Omega is empty
    if Omega.size == 0:
        d = Omega.shape[1] if Omega.ndim > 1 else 0
        if d == 0:
            # With no points and no column dimension, d cannot be inferred.
            if sizeOmega_C > 0:
                raise ValueError("Cannot infer dimension 'd' from an empty Omega set if sizeOmega_C > 0.")
            return np.empty((0, 0), dtype=int)
    else:
        d = Omega.shape[1]  # Dimension of the multi-indices

    # Python ints avoid silent overflow when number_nodes is a NumPy integer.
    total_points = int(number_nodes) ** int(d)

    if sizeOmega_C > total_points:
        raise ValueError(f"Requested size for Omega_C ({sizeOmega_C}) exceeds total possible unique points ({total_points}).")

    all_indices = {tuple(idx) for idx in Omega}

    # Calculate how many points are available to choose from
    available_points_count = total_points - len(all_indices)
    if sizeOmega_C > available_points_count:
        raise ValueError(f"Requested size for Omega_C ({sizeOmega_C}) exceeds available unique points ({available_points_count}) not in Omega.")

    max_attempts_per_point = 100  # Max attempts to find a unique point
    max_total_attempts = max_attempts_per_point * sizeOmega_C

    Omega_C = []
    attempts = 0
    while len(Omega_C) < sizeOmega_C and attempts < max_total_attempts:
        candidate = random_state.randint(0, number_nodes, size=(d,))
        key = tuple(candidate)
        if key not in all_indices:
            Omega_C.append(candidate)
            all_indices.add(key)  # Also avoid duplicates within Omega_C itself
        attempts += 1

    if len(Omega_C) < sizeOmega_C:
        raise ValueError(f"Could not generate enough unique indices for Omega_C. Requested: {sizeOmega_C}, Generated: {len(Omega_C)}. Consider increasing max_attempts_per_point or checking if enough disjoint points exist.")

    # reshape keeps the documented (sizeOmega_C, d) shape even when sizeOmega_C == 0.
    return np.array(Omega_C, dtype=int).reshape(len(Omega_C), d)

# Chebyshev-Gauss-Lobatto points (extrema of T_n) mapped to the interval [0, 1]
def chebyshev_nodes(n):
    """
    Generate the n + 1 Chebyshev-Gauss-Lobatto points mapped to [0, 1].

    The points are x_k = (cos(pi * k / n) + 1) / 2 for k = 0, ..., n. These are
    the extrema of the Chebyshev polynomial T_n (Chebyshev points of the second
    kind), affinely mapped from [-1, 1] to [0, 1]. They include both endpoints
    and are returned in decreasing order (x_0 = 1, x_n = 0).

    They are *not* the Chebyshev nodes of the first kind (the roots of T_{n+1}),
    which exclude the endpoints.

    Parameters:
    - n (int): Number of intervals (number of nodes will be n + 1). Must be >= 0.

    Returns:
    - np.array: Array of (n + 1) nodes in [0, 1]. For n == 0 the single node 0.5
      (the interval midpoint) is returned by convention.

    Raises:
    - ValueError: If n is negative.
    """
    if n < 0:
        raise ValueError("Number of intervals n must be non-negative")
    if n == 0:
        # pi * k / n is undefined for n == 0; use the midpoint convention instead.
        return np.array([0.5])
    k = np.arange(n + 1)  # Indices 0 to n
    q = np.cos(np.pi * k / n)  # Nodes in [-1, 1], decreasing
    return (q + 1) / 2  # Affine map [-1, 1] -> [0, 1]

# Map integer multi-indices to Chebyshev node coordinates in [0, 1]
def map_to_chebyshev_nodes(Omega, n_nodes):
    """
    Map multi-indices from range [0, n_nodes - 1] to Chebyshev nodes in [0, 1] for each dimension.

    Every entry ``Omega[i, j]`` is replaced by
    ``chebyshev_nodes(n_nodes - 1)[Omega[i, j]]``. The number of dimensions is
    taken from ``Omega.shape[1]``, so the result does not depend on any
    module-level ``d``.

    Parameters:
    - Omega: 2D array-like of shape (size, d) with multi-indices in
      [0, n_nodes - 1]. Values are cast to int before the lookup.
    - n_nodes: Total number of nodes per dimension (size of the mode). Must be >= 1.

    Returns:
    - 2D NumPy float array of shape (size, d) with values mapped to Chebyshev nodes.

    Raises:
    - ValueError: If n_nodes < 1 or Omega is not 2D.
    - IndexError: If any index lies outside [0, n_nodes - 1]. Negative indices
      are rejected instead of silently wrapping around to the last nodes.
    """
    if n_nodes < 1:
        raise ValueError("n_nodes (total number of points) must be positive")

    Omega = np.asarray(Omega)
    if Omega.ndim != 2:
        raise ValueError(f"Input Omega to map_to_chebyshev_nodes must be 2D (size, d). Got shape {Omega.shape}")

    # chebyshev_nodes(n_nodes - 1) generates exactly n_nodes points
    cheb_nodes = chebyshev_nodes(n_nodes - 1)

    # Ensure indices are integers before using them for indexing
    indexed_omega = Omega.astype(int)
    if indexed_omega.size and (indexed_omega.min() < 0 or indexed_omega.max() >= n_nodes):
        raise IndexError(
            f"Indices in Omega must lie in [0, {n_nodes - 1}] for n_nodes={n_nodes}. "
            f"Found min: {indexed_omega.min()}, max: {indexed_omega.max()}"
        )

    # Fancy indexing maps every entry at once and preserves the (size, d) shape.
    return cheb_nodes[indexed_omega]

def create_tt_initial(ranks, n, d, constant_value=1.0):
    """
    Build a TensorTrain whose cores are filled with ``constant_value``.

    Core k has shape (ranks[k], n + 1, ranks[k + 1]), except that the first
    core's left rank and the last core's right rank are fixed to 1. The first
    core always uses ranks[1] as its right rank.

    Args:
        ranks: Indexable TT ranks; presumably the full rank vector of length
            d + 1. Entries are read up to index d - 1 (index 1 when d == 1).
        n: Number of intervals per mode; every mode has size n + 1.
        d: Number of cores (tensor order).
        constant_value: Value written into every core entry.

    Returns:
        TensorTrain: Tensor train built from the d constant cores.
    """
    def _core_shape(k):
        # Boundary cores carry a rank-1 outer edge.
        if k == 0:
            return (1, n + 1, ranks[1])
        if k == d - 1:
            return (ranks[d - 1], n + 1, 1)
        return (ranks[k], n + 1, ranks[k + 1])

    return TensorTrain([np.ones(_core_shape(k)) * constant_value for k in range(d)])

def create_tt_random(ranks, n, d, seed=None):
    """
    Build a TensorTrain whose core entries are uniform samples in [0, 1).

    Core shapes follow the same rule as the constant initializer: core k is
    (ranks[k], n + 1, ranks[k + 1]), with the first core's left rank and the
    last core's right rank fixed to 1 (the first core uses ranks[1] on the
    right).

    Args:
        ranks: Indexable TT ranks, read up to index d - 1 (index 1 when d == 1).
        n: Number of intervals per mode; every mode has size n + 1.
        d: Number of cores (tensor order).
        seed: Seed passed to ``np.random.default_rng``. The cores are drawn
            in order k = 0, ..., d - 1 from one Generator, so a fixed seed
            reproduces the same tensor.

    Returns:
        TensorTrain: Tensor train built from the d random cores.
    """
    generator = np.random.default_rng(seed)

    def _core_shape(k):
        if k == 0:
            return (1, n + 1, ranks[1])
        if k == d - 1:
            return (ranks[d - 1], n + 1, 1)
        return (ranks[k], n + 1, ranks[k + 1])

    # The comprehension draws the cores sequentially, preserving the RNG stream order.
    return TensorTrain([generator.random(_core_shape(k)) for k in range(d)])

def convertir_ttml_a_t3f(ttml_tensor):
    """
    Convert a ttml.TensorTrain into a t3f.TensorTrain.

    The ttml cores are passed unchanged to ``t3f.TensorTrain``, which infers
    the shape and TT ranks from the core shapes. Progress and diagnostics are
    printed (messages are in Spanish).

    Args:
        ttml_tensor: The ttml.TensorTrain to convert.

    Returns:
        t3f.TensorTrain | None: The converted tensor. Returns None when the input
        is None, or when building or inspecting the t3f object raises.
    """
    if ttml_tensor is None:
        print("Error: El objeto ttml.TensorTrain de entrada es None. No se puede convertir.")
        return None

    print("\n--- Proceso de Conversión a t3f ---")

    # Report the source shape; a failure here is informational only.
    try:
        inferred_shape = get_tt_shape(ttml_tensor)
    except Exception as err:
        print(f"Error al obtener la forma del ttml.TensorTrain original (usando get_tt_shape): {err}")
    else:
        print(f"Forma inferida del ttml.TensorTrain original (usando get_tt_shape): {inferred_shape}")

    source_cores = ttml_tensor.cores

    try:
        converted = t3f.TensorTrain(source_cores)
        print("Creación del objeto t3f.TensorTrain exitosa.")
        print("\nPropiedades del t3f.TensorTrain convertido:")
        print(f"  Shape (usando get_shape()): {converted.get_shape()}")
        print(f"  TT Rank (usando t3f.tt_ranks()): {t3f.tt_ranks(converted)}")
    except Exception as err:
        print(f"Error inesperado al crear o acceder a atributos del objeto t3f.TensorTrain: {type(err).__name__}: {err}")
        print("Esto podría indicar un problema con la instalación, la versión o el backend de t3f, o un problema subyacente con los cores de entrada.")
        print("Asegúrate de que t3f está correctamente instalado ('pip install t3f') y de que los cores de ttml tienen las formas esperadas.")
        return None
    return converted

# --- Funciones auxiliares ---
def random_idx_t3f(tt_tensor, N, random_state: np.random.RandomState):
    """
    Draw N random multi-indices for a tensor train.

    For each mode size in ``tt_tensor.shape``, N indices are sampled uniformly
    with replacement via ``random_state.choice``, one mode after another. The
    per-mode samples become the columns of the result.

    Args:
        tt_tensor: Tensor train exposing an iterable ``.shape`` of mode sizes.
            NOTE(review): ``.shape`` is used instead of
            ``.get_shape().as_list()``; confirm the installed t3f version
            provides it.
        N: Number of multi-indices to draw.
        random_state: RandomState used for reproducibility.

    Returns:
        np.ndarray: Integer array of shape (N, number_of_modes).
    """
    columns = []
    for mode_size in tt_tensor.shape:
        columns.append(random_state.choice(mode_size, size=N))
    return np.stack(columns, axis=1)

# --- FUNCIÓN NUEVA: Generar el conjunto de validación (test set) ---
def generate_validation_set(
    sizeOmega_C: int,
    d: int,
    number_nodes: int,
    custom_function,
    seed: int
):
    """
    Generate a reproducible validation set (Omega_C, A_Omega_C).

    Indices are drawn uniformly with replacement from [0, number_nodes - 1]^d
    using ``np.random.RandomState(seed)``. They are NOT made disjoint from any
    training set and may contain duplicates. Each index is mapped to Chebyshev
    coordinates, and ``custom_function`` is evaluated on the whole mapped array
    at once.

    Args:
        sizeOmega_C (int): Number of validation points.
        d (int): Tensor dimension (number of index components).
        number_nodes (int): Number of nodes per dimension.
        custom_function (callable): Ground-truth function. It receives the
            (sizeOmega_C, d) array of mapped coordinates and must return
            sizeOmega_C values, as a 1D array-like or a 2D array with one
            column.
        seed (int): Seed for the random number generator.

    Returns:
        tuple: (Omega_C_tf, A_Omega_C_tf). Omega_C_tf is a tf.int32 tensor of
        shape (sizeOmega_C, d); A_Omega_C_tf is a tf.float64 tensor of values.

    Raises:
        ValueError: If custom_function returns a 2D array with more than one column.
    """
    rng = np.random.RandomState(seed)

    Omega_C_indices_np = rng.randint(0, number_nodes, size=(sizeOmega_C, d))
    Omega_C_mapped_np = map_to_chebyshev_nodes(Omega_C_indices_np, number_nodes)

    # np.asarray also accepts lists/scalars, unlike calling .astype on the result.
    A_Omega_C_np = np.asarray(custom_function(Omega_C_mapped_np), dtype=np.float64)

    # Make A_Omega_C_np 1D: a single-column 2D result is flattened
    if A_Omega_C_np.ndim > 1:
        if A_Omega_C_np.shape[1] != 1:
            raise ValueError("custom_function debe devolver un array 1D o un array 2D con una columna para A_Omega_C")
        A_Omega_C_np = A_Omega_C_np.flatten()

    Omega_C_tf = tf.constant(Omega_C_indices_np, dtype=tf.int32)
    A_Omega_C_tf = tf.constant(A_Omega_C_np, dtype=tf.float64)

    return Omega_C_tf, A_Omega_C_tf

def process_and_verify(indices_list, evaluations_list, n_chebyshev_nodes):
    """
    Deduplicate paired (index, evaluation) data and re-check it against custom_function.

    Rows of ``indices_list`` that repeat an earlier row are dropped together
    with their paired evaluations; the first occurrence is kept and the
    original order is preserved. Every surviving point is then passed to
    ``verify_evaluation_with_mapping``, which re-evaluates it on the Chebyshev
    grid and prints any mismatch.

    Args:
        indices_list: Array-like of multi-indices (converted to an int array).
        evaluations_list: Array-like of evaluation values (converted to float64).
        n_chebyshev_nodes: Number of Chebyshev nodes per dimension used for the mapping.

    Returns:
        tuple: (final_indices, final_evals) as NumPy arrays, or (None, None)
        when deduplication fails (inputs of different lengths).
    """
    unique_idx, unique_evals, _ = remove_paired_repetitions(
        np.array(indices_list, dtype=int),
        np.array(evaluations_list, dtype=np.float64),
    )
    if unique_idx is None:
        return None, None

    verify_evaluation_with_mapping(unique_idx, unique_evals, n_chebyshev_nodes)
    return unique_idx, unique_evals

def remove_paired_repetitions(indices_list, evaluations_list):
    """
    Remove repeated multi-indices together with their paired evaluations.

    Rows of ``indices_list`` are compared as whole multi-indices. For each
    distinct row only the first occurrence is kept; surviving rows and the
    evaluations at the same positions stay in their original order.
    Before/after counts are printed.

    Args:
        indices_list: Array-like of multi-indices (e.g. shape (m, d)); plain
            Python lists are accepted.
        evaluations_list: Array-like of m evaluations paired with indices_list.

    Returns:
        tuple: (final_indices, final_evaluations, counts), where counts is
        (initial_indices_count, initial_evals_count, final_indices_count,
        final_evals_count). Returns (None, None, None) if the lengths differ.
    """
    if len(indices_list) != len(evaluations_list):
        print("¡Las listas tienen longitudes diferentes, no se puede procesar!")
        return None, None, None

    # Convert so that plain lists can be fancy-indexed below.
    indices_arr = np.asarray(indices_list)
    evals_arr = np.asarray(evaluations_list)

    initial_indices_count = len(indices_arr)
    initial_evals_count = len(evals_arr)

    if initial_indices_count == 0:
        keep = np.array([], dtype=int)
    else:
        # return_index gives the first position of each distinct row;
        # sorting those positions restores the original order.
        _, first_occurrence = np.unique(indices_arr, axis=0, return_index=True)
        keep = np.sort(first_occurrence)

    final_indices_list = indices_arr[keep]
    final_evaluations_list = evals_arr[keep]

    final_indices_count = len(final_indices_list)
    final_evals_count = len(final_evaluations_list)

    print("Conteo de elementos antes de eliminar repeticiones:")
    print(f"  best_indices_list: {initial_indices_count}")
    print(f"  best_evaluations_list: {initial_evals_count}")

    print("\nConteo de elementos después de eliminar repeticiones:")
    print(f"  best_indices_list: {final_indices_count}")
    print(f"  best_evaluations_list: {final_evals_count}")

    if final_indices_count == final_evals_count:
        print("\nEl número de elementos coincide en ambas listas después de eliminar repeticiones.")
    else:
        print("\n¡El número de elementos NO coincide en ambas listas después de eliminar repeticiones!")

    return final_indices_list, final_evaluations_list, (initial_indices_count, initial_evals_count, final_indices_count, final_evals_count)

def verify_evaluation_with_mapping(final_indices, final_evals, n_nodes=20, func=None):
    """
    Check that each stored evaluation matches the ground-truth function.

    For every i, ``final_indices[i]`` is mapped to Chebyshev coordinates with
    ``map_to_chebyshev_nodes(..., n_nodes)``, and ``func`` is evaluated on the
    resulting 1D point. The result is compared with ``final_evals[i]`` using
    ``np.allclose``. Mismatches and a summary are printed; nothing is returned.

    Args:
        final_indices: Sequence/array of multi-indices, one per row.
        final_evals: Sequence/array of stored evaluations, same length.
        n_nodes: Number of Chebyshev nodes per dimension. The default of 20 is
            a hard-coded default; pass the grid size actually used.
        func: Callable to verify against. When None, the module-level
            ``custom_function`` is looked up at call time (previous behaviour).
    """
    if final_indices is None or final_evals is None or len(final_indices) != len(final_evals):
        print("¡Las listas finales no son válidas o tienen longitudes diferentes!")
        return

    if func is None:
        func = custom_function  # Backward-compatible default: the notebook's global function

    mismatched_count = 0
    for i, (index_vec, expected_value) in enumerate(zip(final_indices, final_evals)):
        index_vec = np.asarray(index_vec)
        mapped_point = map_to_chebyshev_nodes(index_vec.reshape(1, -1), n_nodes)[0]
        evaluated_value = func(mapped_point)

        if not np.allclose(evaluated_value, expected_value):
            mismatched_count += 1
            print(f"¡Desajuste en el índice {i}!")
            print(f"  Vector de índice: {index_vec}")
            print(f"  Vector mapeado: {mapped_point}")
            print(f"  Valor esperado: {expected_value}")
            print(f"  Valor calculado: {evaluated_value}")

    if mismatched_count == 0:
        print("¡Todos los elementos de final_evals corresponden a la evaluación de custom_function en los nodos de Chebyshev mapeados desde final_indices!")
    else:
        print(f"\nSe encontraron {mismatched_count} desajustes entre final_evals y la evaluación de los nodos de Chebyshev mapeados desde final_indices.")

def _min_normalized_hamming(candidates, reference):
    """Return, for each candidate row, its minimum fraction of differing coordinates to any reference row (TensorFlow)."""
    cand = tf.expand_dims(tf.constant(candidates, dtype=tf.int32), 1)  # (n_cand, 1, d)
    ref = tf.expand_dims(tf.constant(reference, dtype=tf.int32), 0)    # (1, n_ref, d)
    dists = tf.reduce_mean(tf.cast(tf.not_equal(cand, ref), tf.float32), axis=2)  # (n_cand, n_ref)
    return tf.reduce_min(dists, axis=1).numpy()


# Greedy max-min subsampling of multi-indices, using TensorFlow for the distance computations
def subsample_indices(indices, subsample_size, d, initial_random=2000, batch_size=200, random_state=None):
    """
    Subsample a set of indices: random seed selection followed by batched greedy max-min selection.

    First ``initial_random`` rows are chosen at random without replacement.
    In each greedy iteration, the ``batch`` remaining rows with the largest
    minimum normalized Hamming distance (fraction of differing coordinates) to
    the selected rows are added. This repeats until ``subsample_size`` rows are
    selected.

    Each remaining row's minimum distance is kept across iterations and only
    updated against the newly added batch. The selection is the same as
    recomputing distances to the whole selected set every iteration, but each
    iteration costs O(n_remaining * batch * d) instead of
    O(n_remaining * n_selected * d).

    Parameters:
    - indices: array-like of shape (n, d) with the indices to subsample.
    - subsample_size: number of indices to select.
    - d: number of dimensions of the indices (not used by the computation;
      kept for backward compatibility).
    - initial_random: number of indices selected at random before the greedy phase.
    - batch_size: number of indices added per greedy iteration.
    - random_state: optional np.random.RandomState for the initial random
      selection; when None, the global ``np.random`` state is used.

    Returns:
    - np.array with the selected rows, in selection order. If
      subsample_size >= n, ``indices`` is returned unchanged.
    """
    n = len(indices)
    if subsample_size >= n:
        print(f"Advertencia: subsample_size ({subsample_size}) es mayor o igual al número de índices ({n}). Devolviendo todos los índices.")
        return indices

    indices = np.asarray(indices)
    rng = np.random if random_state is None else random_state

    # Initial random subset
    initial_random = min(initial_random, subsample_size, n)
    selected = list(rng.choice(n, initial_random, replace=False))
    remaining = np.setdiff1d(np.arange(n), selected)

    # Minimum distance of every remaining row to the current selected set.
    # With an empty selected set, every row is equally far (ties go to the lowest position).
    if selected:
        min_dists = _min_normalized_hamming(indices[remaining], indices[selected])
    else:
        min_dists = np.full(len(remaining), np.inf, dtype=np.float32)

    total_iterations = (subsample_size - len(selected) + batch_size - 1) // batch_size
    print(f"Iniciando submuestreo: {len(selected)}/{subsample_size} índices seleccionados inicialmente. "
          f"Total de iteraciones greedy: {total_iterations}")

    for i in range(total_iterations):
        batch = min(batch_size, subsample_size - len(selected))
        if batch <= 0 or len(remaining) == 0:
            print("No hay más índices restantes o se alcanzó el tamaño deseado.")
            break

        print(f"Iteración {i+1}/{total_iterations}: Calculando distancias para {len(remaining)} índices restantes...")

        # Take the `batch` remaining rows farthest from the selected set.
        # top_k puts the lower position first on ties.
        _, top_indices = tf.math.top_k(tf.constant(min_dists, dtype=tf.float32), k=batch)
        top_indices = top_indices.numpy()
        newly_selected = remaining[top_indices]

        selected.extend(newly_selected)
        remaining = np.delete(remaining, top_indices)
        min_dists = np.delete(min_dists, top_indices)

        # Only the newly selected rows can lower the remaining minimum distances.
        if len(remaining) > 0:
            min_dists = np.minimum(min_dists, _min_normalized_hamming(indices[remaining], indices[newly_selected]))

        print(f"Progreso: {len(selected)}/{subsample_size} índices seleccionados "
              f"({(len(selected)/subsample_size)*100:.1f}%)")

    return indices[selected]

def convertir_t3f_a_ttml(t3f_tensor):
    """
    Convert a t3f.TensorTrain into a ttml.TensorTrain.

    Each t3f core is turned into a NumPy array with ``.numpy()``. The list of
    arrays is passed to ``TensorTrain(cores=..., mode='l', is_orth=False)``.
    Progress and diagnostics are printed (messages are in Spanish).

    Args:
        t3f_tensor: The t3f.TensorTrain to convert.

    Returns:
        ttml.TensorTrain | None: The converted tensor. Returns None when the
        input is None or building/inspecting the ttml object raises.
    """
    if t3f_tensor is None:
        print("Error: El objeto t3f.TensorTrain de entrada es None. No se puede convertir.")
        return None

    print("\n--- Proceso de Conversión de t3f a ttml ---")

    numpy_cores = [tf_core.numpy() for tf_core in t3f_tensor.tt_cores]

    try:
        converted = TensorTrain(cores=numpy_cores, mode='l', is_orth=False)
        print("Creación del objeto ttml.TensorTrain exitosa.")
        print("\nPropiedades del ttml.TensorTrain convertido:")
        print(f"  Dims: {converted.dims}")
        print(f"  TT Rank: {converted.tt_rank}")
    except Exception as err:
        print(f"Error al crear el objeto ttml.TensorTrain: {err}")
        return None
    return converted

import numpy as np
import tensorflow as tf
import t3f

# Import ttml's TensorTrain under the alias used by the isinstance checks below.
# When ttml is unavailable, fall back to an empty placeholder class so the
# module still imports; isinstance checks against it then never match.
try:
    from ttml.tensor_train import TensorTrain as TTMLTensorTrain
except ImportError:
    class TTMLTensorTrain:
        """Placeholder used when ttml is not installed; nothing in this module instantiates it."""

    print("Warning: ttml.TensorTrain not found. The TTML path in the functions might not work correctly.")

def _calculate_tt_approximation_error_value(
    tt_approx_model,
    test_indices_raw,  # tf.Tensor or np.ndarray
    test_true_values_raw,  # tf.Tensor or np.ndarray
    metric: str = 'relative_l2'
) -> float:
    """
    Evaluate a TT model at the given indices and return one aggregate error value.

    Models are evaluated as follows:
    - ttml.TensorTrain: ``.gather`` on NumPy indices.
    - t3f.TensorTrain: ``t3f.gather_nd`` on int32 TensorFlow indices.

    If one of the estimated/true arrays is an (m, 1) column while the other is
    not, the column is flattened. If the shapes still differ, a warning is
    printed and both are truncated to the shorter length. That warning is the
    only output this function prints itself.

    Args:
        tt_approx_model: The ttml.TensorTrain or t3f.TensorTrain approximation model.
        test_indices_raw: Multi-indices of the evaluation points.
        test_true_values_raw: True values at those indices.
        metric (str): One of:
            - 'RMSE'
            - 'relative_l1': mean of |est - true| / max(|true|, 1e-12)
            - 'relative_l2': delegated to calculate_relative_l2_error

    Returns:
        float: The error value, or np.inf for an unrecognised metric.

    Raises:
        TypeError: If the model is neither a ttml nor a t3f TensorTrain.
    """
    true_values = test_true_values_raw.numpy() if tf.is_tensor(test_true_values_raw) else test_true_values_raw

    if isinstance(tt_approx_model, TTMLTensorTrain):
        # ttml gathers from NumPy indices.
        indices_np = test_indices_raw.numpy() if tf.is_tensor(test_indices_raw) else test_indices_raw
        estimated_values = tt_approx_model.gather(indices_np)
    elif isinstance(tt_approx_model, t3f.TensorTrain):
        # t3f gathers from int32 TensorFlow indices.
        if tf.is_tensor(test_indices_raw):
            indices_tf = tf.cast(test_indices_raw, dtype=tf.int32)
        else:
            indices_tf = tf.constant(test_indices_raw, dtype=tf.int32)
        estimated_values = t3f.gather_nd(tt_approx_model, indices_tf).numpy()
    else:
        raise TypeError("Unsupported TT approximation model type. Must be ttml.TensorTrain or t3f.TensorTrain.")

    # Reconcile shapes: flatten a single-column array, then truncate as a last resort.
    if estimated_values.shape != true_values.shape:
        if true_values.ndim > 1 and true_values.shape[1] == 1:
            true_values = true_values.flatten()
        elif estimated_values.ndim > 1 and estimated_values.shape[1] == 1:
            estimated_values = estimated_values.flatten()

        if estimated_values.shape != true_values.shape:
            print(f"Warning: Shape mismatch detected for error calculation. Truncating to min length. Estimated: {estimated_values.shape}, True: {true_values.shape}")
            common_len = min(len(estimated_values), len(true_values))
            estimated_values = estimated_values[:common_len]
            true_values = true_values[:common_len]

    if metric == 'RMSE':
        return np.sqrt(np.mean((estimated_values - true_values) ** 2))
    if metric == 'relative_l1':
        denominator = np.maximum(np.abs(true_values), 1e-12)
        return np.mean(np.abs(estimated_values - true_values) / denominator)
    if metric == 'relative_l2':
        return calculate_relative_l2_error(true_values, estimated_values)
    return np.inf

# ANSI color escape codes, defined here so they are in scope for _print_error_metrics
# (blue for the training set, green for other sets, reset afterwards).
COLOR_BLUE, COLOR_GREEN, COLOR_RESET = "\033[94m", "\033[92m", "\033[0m"


def _print_error_metrics(
    best_tt_approx_model, # Renamed to be generic
    indices,
    true_values,
    metric: str,
    data_set_name: str,
    show_example_values: bool
):
    """
    Compute and print error metrics of a TT model on one dataset.

    The aggregate error comes from ``_calculate_tt_approximation_error_value``.
    The model is then re-evaluated at ``indices`` to print up to 10 example
    points (when ``show_example_values`` is True). For 'relative_l1' it also
    prints per-point max/mean/median/90th/99th-percentile statistics. Summary
    lines are blue when ``data_set_name`` contains "Training Set" and green
    otherwise. Supports ttml.TensorTrain and t3f.TensorTrain models.

    Args:
        best_tt_approx_model: The ttml.TensorTrain or t3f.TensorTrain model.
        indices: Multi-indices of the dataset (tf.Tensor or np.ndarray).
        true_values: True values at those indices (tf.Tensor or np.ndarray).
        metric (str): 'RMSE', 'relative_l1' or 'relative_l2'.
        data_set_name (str): Label used in the printed headers.
        show_example_values (bool): Whether to print example points.

    Raises:
        TypeError: If the model is neither a ttml nor a t3f TensorTrain.
    """
    indices_np = indices.numpy() if tf.is_tensor(indices) else indices
    true_values_np = true_values.numpy() if tf.is_tensor(true_values) else true_values

    overall_error = _calculate_tt_approximation_error_value(
        best_tt_approx_model, indices_np, true_values_np, metric
    )

    # Re-evaluate the model to obtain per-point values for the printed details.
    if isinstance(best_tt_approx_model, TTMLTensorTrain):
        estimated_values_for_print = best_tt_approx_model.gather(indices_np)
    elif isinstance(best_tt_approx_model, t3f.TensorTrain):
        # t3f needs int32 TensorFlow indices
        if tf.is_tensor(indices):
            indices_tf = tf.cast(indices, dtype=tf.int32)
        else:
            indices_tf = tf.constant(indices, dtype=tf.int32)
        estimated_values_for_print = t3f.gather_nd(best_tt_approx_model, indices_tf).numpy()
    else:
        raise TypeError("Unsupported TT approximation model type for printing. Must be ttml.TensorTrain or t3f.TensorTrain.")

    # Reconcile shapes BEFORE computing absolute errors. Subtracting an (m,)
    # array from an (m, 1) array broadcasts to (m, m) and corrupts every
    # per-point value printed below.
    if estimated_values_for_print.shape != true_values_np.shape:
        if true_values_np.ndim > 1 and true_values_np.shape[1] == 1:
            true_values_np = true_values_np.flatten()
        elif estimated_values_for_print.ndim > 1 and estimated_values_for_print.shape[1] == 1:
            estimated_values_for_print = estimated_values_for_print.flatten()

        if estimated_values_for_print.shape != true_values_np.shape:
            print(f"Warning: Shapes of estimated values for printing ({estimated_values_for_print.shape}) and true values ({true_values_np.shape}) do not match for {data_set_name}. This might affect example prints.")
            min_len = min(len(estimated_values_for_print), len(true_values_np))
            estimated_values_for_print = estimated_values_for_print[:min_len]
            true_values_np = true_values_np[:min_len]

    error_values_for_print = np.abs(estimated_values_for_print - true_values_np)

    metric_name = "RMSE (Root Mean Squared Error)" if metric == 'RMSE' else \
                  "Relative L1 Error" if metric == 'relative_l1' else \
                  "Relative L2 Error (Norm)"

    print(f"\n--- Results for {data_set_name} ({len(indices_np)} points) ---")

    # --- Print individual results (if show_example_values is True) ---
    if show_example_values:
        print(f"\nExample results from {data_set_name}:")
        for i in range(min(10, len(indices_np))):
            print(f"Index: {indices_np[i]}, "
                  f"Estimated: {format_value_dynamically(estimated_values_for_print[i])}, "
                  f"Actual: {format_value_dynamically(true_values_np[i])}, "
                  f"Absolute Error: {format_value_dynamically(error_values_for_print[i])}")

    # --- Aggregate statistics ---
    print(f"\nSummary of {metric_name} for {data_set_name}:")

    color_code = COLOR_BLUE if "Training Set" in data_set_name else COLOR_GREEN

    if metric == 'RMSE':
        print(f"{color_code}Total {metric_name}: {format_value_dynamically(overall_error)}{COLOR_RESET}")
        print("(Note: RMSE is a global metric; per-point error percentiles are not typically calculated for RMSE)")
    elif metric == 'relative_l2':
        print(f"{color_code}Total {metric_name}: {overall_error:.10e}{COLOR_RESET}")
        print("(Note: Relative L2 Error is a global norm metric; per-point error percentiles are not typically calculated for this value)")
    elif metric == 'relative_l1':
        denominator = np.maximum(np.abs(true_values_np), 1e-12)
        individual_errors_to_print = error_values_for_print / denominator
        print(f"{color_code}Max {metric_name} per point: {format_value_dynamically(np.max(individual_errors_to_print))}{COLOR_RESET}")
        print(f"{color_code}Mean {metric_name} per point: {format_value_dynamically(np.mean(individual_errors_to_print))}{COLOR_RESET}")
        print(f"{color_code}Median {metric_name} per point: {format_value_dynamically(np.median(individual_errors_to_print))}{COLOR_RESET}")
        print(f"{color_code}90th Percentile of {metric_name} per point: {format_value_dynamically(np.percentile(individual_errors_to_print, 90))}{COLOR_RESET}")
        print(f"{color_code}99th Percentile of {metric_name} per point: {format_value_dynamically(np.percentile(individual_errors_to_print, 99))}{COLOR_RESET}")

def check_approximation_accuracy(
    best_tt_approx_model, # Renamed to be generic
    Omega,    # Training indices
    A_Omega,    # Training true values
    test_indices,
    test_true_values,
    metric: str = 'relative_l2',
    show_example_values: bool = False
):
    """
    Report how well a Tensor Train approximation fits the training and test sets.

    After validating the metric and the model type, the same report is printed
    for the training set (Omega, A_Omega) and then for the test set
    (test_indices, test_true_values). tf.Tensor inputs are converted to NumPy
    where needed. Supports ttml.TensorTrain and t3f.TensorTrain models.

    Args:
        best_tt_approx_model: The ttml.TensorTrain or t3f.TensorTrain approximation model.
        Omega: Training multi-indices (tf.Tensor or np.ndarray).
        A_Omega: True values at Omega (tf.Tensor or np.ndarray).
        test_indices: Test multi-indices (Omega_C) (tf.Tensor or np.ndarray).
        test_true_values: True values at test_indices (A_Omega_C).
        metric (str, optional): 'RMSE', 'relative_l1' or 'relative_l2' (default).
        show_example_values (bool, optional): If True, also print a few example
            estimated vs. actual values with their absolute errors.

    Raises:
        ValueError: For an unsupported metric.
        TypeError: For an unsupported model type.

    Prints:
        The aggregate error for each set. Per-point percentile statistics are
        printed only for 'relative_l1'.
    """
    supported_metrics = ('RMSE', 'relative_l1', 'relative_l2')
    if metric not in supported_metrics:
        raise ValueError(f"Unsupported metric: {metric}. Choose from 'RMSE', 'relative_l1', or 'relative_l2'.")

    if not isinstance(best_tt_approx_model, (TTMLTensorTrain, t3f.TensorTrain)):
        raise TypeError("Unsupported TT approximation model type. 'best_tt_approx_model' must be an instance of ttml.TensorTrain or t3f.TensorTrain.")

    print(f"\n--- Checking Approximation Accuracy for Model Type: {type(best_tt_approx_model).__name__} ---")

    # Training set first, then test set.
    for set_indices, set_values, set_name in (
        (Omega, A_Omega, "Training Set"),
        (test_indices, test_true_values, "Test Set"),
    ):
        _print_error_metrics(
            best_tt_approx_model,
            set_indices,
            set_values,
            metric,
            set_name,
            show_example_values,
        )

# TT-Cross

In [26]:
# Print the source of ttml's reference implementation of _compute_multi_indices
# (requires ttml to be installed) for comparison with the local copy below.
import ttml.tt_cross
import inspect

try:
    source_init = inspect.getsource(ttml.tt_cross._compute_multi_indices)
    print(source_init)
except (AttributeError, OSError, TypeError):
    # AttributeError: the function is missing in this ttml version;
    # OSError / TypeError: its source code cannot be retrieved.
    print("No se pudo encontrar _compute_multi_indices en ttml.tt_cross.")

# Local re-implementation of ttml.tt_cross._compute_multi_indices.

def _compute_multi_indices(ind, ind_old, direction):
    """
    Compute new multiindex from old multiindex and pairs of (alpha_{k-1},i_k) as
    described in Savistyanov-Oseledets. This guarantees a nested sequence of
    multiindices, and works for both the left and right indices.
    """
    r = ind.shape[1]
    if direction == "RL":
        dim_indices, previous_indices = ind
    else:
        previous_indices, dim_indices = ind

    if ind_old is None:
        return dim_indices.reshape(1, r)
    else:
        ind_new = np.zeros((len(ind_old) + 1, r), dtype=np.int32)
        if direction == "RL":
            ind_new[1:, :] = ind_old[:, previous_indices]
            ind_new[0, :] = dim_indices
        elif direction == "LR":
            ind_new[:-1, :] = ind_old[:, previous_indices]
            ind_new[-1, :] = dim_indices
        else:
            raise ValueError("Direction has to be 'LR' or 'RL'")
    return ind_new



In [14]:
# --- Reference implementation from ttml ---
# When the ttml package is installed, inspect.getsource prints the library's
# own implementation so it can be compared with the local definition below.
import inspect

source_init = None
try:
    import ttml.tt_cross
    source_init = inspect.getsource(ttml.tt_cross.maxvol)
    print(source_init)
except (ImportError, AttributeError, OSError, TypeError):
    # ImportError: ttml not installed; AttributeError: name not present;
    # OSError/TypeError: source code not retrievable.
    print("No se pudo encontrar maxvol en ttml.tt_cross.")

# A local implementation follows, usable whether or not ttml is installed.

def maxvol(A, eps=1e-2, niters=100):
    """Return row indices of a quasi-maximum-volume submatrix of ``A``.

    The rows are initialised from the LU pivots of an orthonormal basis of
    ``A`` (economic QR). They are then improved by greedy row interchanges:
    while the largest entry of the coefficient matrix exceeds ``1 + eps`` in
    magnitude, that row is swapped in. At most ``niters`` swaps are made.

    Args:
        A: 2-D array of shape ``(n, r)``.
        eps: Tolerance above 1 for the largest coefficient magnitude.
        niters: Maximum number of row interchanges.

    Returns:
        np.ndarray: Sorted indices of the selected rows, or ``np.arange(n)``
        when ``n <= r``.
    """
    n_rows, n_cols = A.shape
    if not n_rows > n_cols:
        # Not enough rows to choose from: every row is selected.
        return np.arange(n_rows)

    basis = scipy.linalg.qr(A, mode="economic")[0]
    # dgetrf returns (lu, piv, info); only the pivot vector is needed.
    pivots = scipy.linalg.lapack.dgetrf(basis)[1]
    rows = _piv_to_ind(pivots, n_rows)[:n_cols]

    # NOTE(review): _right_solve is defined elsewhere; presumably it expresses
    # every row of `basis` in terms of the selected rows basis[rows].
    coeffs = _right_solve(basis[rows[:n_cols]], basis)

    for _ in range(niters):
        row, col = np.unravel_index(np.argmax(np.abs(coeffs)), coeffs.shape)
        pivot = coeffs[row, col]
        if np.abs(pivot) <= 1 + eps:
            break
        # Rank-one update for swapping `row` into slot `col`; the row being
        # replaced must be read before `rows` is updated.
        replaced = rows[col]
        coeffs += np.outer(coeffs[:, col], coeffs[replaced, :] - coeffs[row, :]) / pivot
        rows[col] = row

    rows.sort()
    return rows



In [24]:
# --- Reference implementation from ttml ---
# When the ttml package is installed, inspect.getsource prints the library's
# own implementation so it can be compared with the local definition below.
import inspect

source_init = None
try:
    import ttml.tt_cross
    source_init = inspect.getsource(ttml.tt_cross._maxvol_tensor)
    print(source_init)
except (ImportError, AttributeError, OSError, TypeError):
    # ImportError: ttml not installed; AttributeError: name not present;
    # OSError/TypeError: source code not retrievable.
    print("No se pudo encontrar _maxvol_tensor en ttml.tt_cross.")

# A local implementation follows, usable whether or not ttml is installed.

def _maxvol_tensor(X, mu, transpose=False):
    """Unfold ``X`` along mode ``mu`` and select its maxvol rows.

    The axes of ``X`` are permuted so that mode ``mu`` comes last. The result
    is reshaped into a matrix with ``X.shape[mu]`` columns, and ``maxvol``
    picks rows of that matrix.

    Args:
        X: N-D array.
        mu: Mode kept as the column dimension of the unfolding.
        transpose: If True, the selected submatrix is returned transposed.

    Returns:
        tuple: ``(ind, R)``. ``ind`` stacks, one row per remaining mode, the
        multi-index of each selected row. ``R`` is the selected submatrix,
        transposed when ``transpose`` is True.
    """
    other_axes = [axis for axis in range(len(X.shape)) if axis != mu]
    unfolded = X.transpose(other_axes + [mu]).reshape(-1, X.shape[mu])

    selected_rows = maxvol(unfolded)
    submatrix = unfolded[selected_rows, :]

    remaining_shape = X.shape[:mu] + X.shape[mu + 1 :]
    multi_index = np.stack(np.unravel_index(selected_rows, remaining_shape))

    return multi_index, (submatrix.T if transpose else submatrix)



In [25]:
# --- Reference implementation from ttml ---
# When the ttml package is installed, inspect.getsource prints the library's
# own implementation so it can be compared with the local definition below.
import inspect

source_init = None
try:
    import ttml.tt_cross
    source_init = inspect.getsource(ttml.tt_cross._init_tt_cross)
    print(source_init)
except (ImportError, AttributeError, OSError, TypeError):
    # ImportError: ttml not installed; AttributeError: name not present;
    # OSError/TypeError: source code not retrievable.
    print("No se pudo encontrar _init_tt_cross en ttml.tt_cross.")

# A local implementation follows, usable whether or not ttml is installed.

def _init_tt_cross(tt):
    """Generate initial set of R-matrices and right-indices for TT-cross.

    Same for DMRG as for regular TT-cross. This version follows the paper
    instead of Matlab code."""
    tt.orthogonalize("l")
    nd = len(tt)
    P_mats = [None] * (nd + 1)
    P_mats[0] = np.array([[1]])
    P_mats[-1] = np.array([[1]])
    index_array = [None] * (nd + 1)
    R = np.array([[1]])

    for i in range(nd - 1, 0, -1):
        core = tt[i]
        core = np.einsum("ijk,kl->ijl", core, R)

        # RQ decomposition of core
        Q, R = _qr_tensor(core, 0, True)

        tt[i] = Q
        Q = np.einsum("ijk,kl", Q, P_mats[i + 1])

        # Max vol indices
        # ind = maxvol(core.T)
        ind, P = _maxvol_tensor(Q, 0, True)
        P_mats[i] = P

        # Compute new indices from previous and maxvol
        ind_new = _compute_multi_indices(ind, index_array[i + 1], "RL")
        index_array[i] = ind_new
    tt[0] = np.einsum("ijk,kl->ijl", tt[0], R)

    # tt is 

In [11]:
# --- Reference implementation from ttml ---
# When the ttml package is installed, inspect.getsource prints the library's
# implementation of _sweep_step_regular for reference.
import inspect

source_init = None
try:
    import ttml.tt_cross
    source_init = inspect.getsource(ttml.tt_cross._sweep_step_regular)
    print("--- _sweep_step_regular Source ---")
    print(source_init)
except (ImportError, AttributeError, OSError, TypeError):
    # ImportError: ttml not installed; AttributeError: name not present;
    # OSError/TypeError: source code not retrievable.
    print("No se pudo encontrar _sweep_step_regular en ttml.tt_cross.")

# No local implementation of _sweep_step_regular is defined in this cell; the
# TT-cross driver cell imports it from ttml.tt_cross.

--- _init_tt_cross Source ---
def _sweep_step_regular(
    i, direction, tt, index_array, index_fun, Pmats, cache=None, verbose=False
):
    """
    Do one step of the DMRG TT-cross algorithm, sweeping in a specified
    direction.

    Parameters
    ----------
    i : int
        Left index of the supercore
    direction : str
        Either "LR" or "RL", corresponding to a sweep in the left-to-right
        and right-to-left direction respectively
    tt : TensorTrain
        TensorTrain to be modified
    index_array : list[np.ndarray]
        list of left_indices and right_indices. At step `i`, `index_array[i+1]`
        will be modified
    index_fun : function
        Function mapping indices to function values to be used for fitting
    Pmats : list[np.ndarray]
        The list of matrices to be used to compute maxvol at step i. At step
        `i`, `Pmats[i+1]` will be modified.
    cache : dict or None (default: None)
        Sample cache. The TT-cross drivers in this notebook pass a dict with
        keys "inds" and "func_vals"; how this function uses it is not visible
        here (confirm against ttml.tt_cross).
    verbose: bool (default: False)
        Print convergence information every step.
    """
    # NOTE(review): this is the signature and docstring printed from
    # ttml.tt_cross by the cell above. The body was not captured, so this
    # stub only returns None if executed.

    

In [15]:
import numpy as np
import time
import tensorflow as tf
from ttml.tt_cross import random_idx, _init_tt_cross, _sweep_step_regular, _sweep_step_dmrg, index_function_wrapper # Assuming _sweep_step_dmrg exists
from ttml.tensor_train import TensorTrain # Assuming this is your TensorTrain class

def collecting_index_fun(original_index_fun):
    """Wrap ``original_index_fun`` in a transparent pass-through callable.

    The returned function forwards its ``indices`` argument unchanged and
    returns whatever ``original_index_fun`` returns. No indices or values are
    collected here; any input-shape handling or collection is left to the
    wrapped function.
    """
    def wrapped_index_fun(indices):
        result = original_index_fun(indices)
        return result

    return wrapped_index_fun

# --- Merged tt_cross_regular_v2 (uses _sweep_step_regular) ---
def tt_cross_regular_v2(
    tt,
    index_fun,
    Omega_C,  # Test-set indices
    A_Omega_C,  # Test-set true values
    tol_flattening: float = 1e-2,
    tol_precision: float = 1e-4,
    metric: str = 'relative_l2',  # 'rmse' or 'relative_l2'
    max_its: int = 10,
    verbose: bool = False,
    inplace: bool = True
):
    """Run regular TT-cross sweeps with an external validation set.

    Each iteration performs one full sweep of ``_sweep_step_regular``. Sweeps
    alternate left-to-right and right-to-left. After each sweep the current
    TT is scored on ``(Omega_C, A_Omega_C)`` with the selected metric.

    Args:
        tt: Initial TensorTrain; modified in place unless ``inplace`` is False.
        index_fun: Function mapping a batch of indices to function values.
        Omega_C: Validation indices, evaluated with ``tt.gather``.
        A_Omega_C: Validation values matching ``Omega_C``.
        tol_flattening: From the fourth sweep on, stop when the latest error
            is not at least this relative fraction below the largest of the
            three previous errors. ``None`` disables the check.
        tol_precision: Stop once the validation error drops below this value.
            ``None`` disables the check.
        metric: ``'rmse'`` or ``'relative_l2'``.
        max_its: Maximum number of sweeps.
        verbose: Print per-sweep progress.
        inplace: If False, work on ``tt.copy()``.

    Returns:
        tuple: ``(tt, stop_reason, iterations_completed, cache)``.
        ``stop_reason`` is ``"precision"``, ``"flattening"`` or
        ``"max_iterations"``. ``cache`` is the dict (keys ``"inds"`` and
        ``"func_vals"``) handed to every sweep step. Per-sweep errors are
        stored in ``tt.errors``.

    Raises:
        ValueError: If ``metric`` is not supported.
    """
    if metric not in ['rmse', 'relative_l2']:
        raise ValueError(f"Unsupported metric: {metric}. Choose 'rmse' or 'relative_l2'.")
    error_metric_name = "RMSE" if metric == 'rmse' else "Relative L2 Error"

    if not inplace:
        tt = tt.copy()
    tt, Pmats, index_array = _init_tt_cross(tt)
    direction = "LR"

    # The external validation set is also the cache handed to each sweep step.
    cache = {
        "inds": Omega_C,
        "func_vals": A_Omega_C
    }

    errors = []
    stop_reason = "max_iterations"
    iterations_completed = 0

    for j in range(max_its):
        if verbose:
            print(f"Sweep {j}, direction {direction}. Algorithm: tt_cross_regular_v2")

        if direction == "LR":
            for i in range(len(tt)):
                _sweep_step_regular(
                    i, "LR", tt, index_array, index_fun, Pmats, cache, verbose
                )
            direction = "RL"
        else:
            for i in range(len(tt) - 1, -1, -1):
                _sweep_step_regular(
                    i, "RL", tt, index_array, index_fun, Pmats, cache, verbose
                )
            direction = "LR"
        iterations_completed = j + 1

        # Flatten both sides: an (N, 1) target against (N,) predictions would
        # otherwise broadcast to an (N, N) difference and corrupt the metric.
        y_pred = np.asarray(tt.gather(cache["inds"])).reshape(-1)
        y_true = np.asarray(cache["func_vals"]).reshape(-1)

        current_error = np.inf
        if y_true.size > 0 and y_pred.size == y_true.size:
            if metric == 'rmse':
                current_error = np.sqrt(np.mean((y_pred - y_true) ** 2))
            elif metric == 'relative_l2':
                current_error = calculate_relative_l2_error(y_true, y_pred)
        errors.append(current_error)

        if verbose:
            print(f"Last {error_metric_name} (from validation set): {current_error:.10e}")

        if tol_precision is not None and current_error < tol_precision:
            stop_reason = "precision"
            break

        if tol_flattening is not None and len(errors) > 3:
            max_prev_error = np.max(errors[-4:-1])
            if max_prev_error > 1e-12 and max_prev_error != np.inf:
                change = (errors[-1] - max_prev_error) / max_prev_error
                if change > -tol_flattening:
                    stop_reason = "flattening"
                    break
            elif errors[-1] == max_prev_error:
                # Tiny or infinite previous errors: stop only if stuck.
                stop_reason = "flattening"
                break

    tt.orthogonalize()
    tt.errors = np.array(errors)
    return tt, stop_reason, iterations_completed, cache


# --- Modified tt_cross_regular_v2_dmrg_step (uses _sweep_step_dmrg) ---
def tt_cross_regular_v2_dmrg_step(
    tt,
    index_fun,
    tol_flattening: float = 1e-2,
    tol_precision: float = 1e-3,
    metric: str = 'relative_l2',  # 'rmse' or 'relative_l2'
    max_its: int = 10,
    verbose: bool = False,
    inplace: bool = True,
    rank_kick: int = 0,
    Omega_C=None,
    A_Omega_C=None,
):
    """Run TT-cross with DMRG sweep steps (``_sweep_step_dmrg``).

    Each iteration sweeps the supercores ``i = 0..d-2``. Sweeps alternate
    left-to-right and right-to-left. After each sweep the current TT is
    scored on a cache of sampled points with the selected metric.

    Args:
        tt: Initial TensorTrain; modified in place unless ``inplace`` is False.
        index_fun: Function mapping a batch of indices to function values.
        tol_flattening: From the fourth sweep on, stop when the latest error
            is not at least this relative fraction below the largest of the
            three previous errors. ``None`` disables the check.
        tol_precision: Stop once the cache error drops below this value.
            ``None`` disables the check.
        metric: ``'rmse'`` or ``'relative_l2'``.
        max_its: Maximum number of sweeps.
        verbose: Print per-sweep progress.
        inplace: If False, work on ``tt.copy()``.
        rank_kick: Forwarded to ``_sweep_step_dmrg``.
        Omega_C: Optional validation indices.
        A_Omega_C: Optional validation values.

    When both ``Omega_C`` and ``A_Omega_C`` are given, they form the error
    cache, as ``tt_cross_regular_v2`` does. Otherwise 200 random indices from
    ``random_idx`` are evaluated with ``index_fun``. Accepting these keywords
    also lets ``optimize_tt_cross_rank_sweep`` call this function: it always
    passes both, which used to raise TypeError here.

    Returns:
        tuple: ``(tt, stop_reason, iterations_completed, cache)``.
        ``stop_reason`` is ``"precision"``, ``"flattening"`` or
        ``"max_iterations"``. Per-sweep errors are stored in ``tt.errors``.

    Raises:
        ValueError: If ``metric`` is not supported.
    """
    if metric not in ['rmse', 'relative_l2']:
        raise ValueError(f"Unsupported metric: {metric}. Choose 'rmse' or 'relative_l2'.")
    error_metric_name = "RMSE" if metric == 'rmse' else "Relative L2 Error"

    if not inplace:
        tt = tt.copy()

    tt, Pmats, index_array = _init_tt_cross(tt)
    direction = "LR"

    if Omega_C is not None and A_Omega_C is not None:
        # External validation set supplied by the caller.
        cache = {
            "inds": np.asarray(Omega_C),
            "func_vals": np.asarray(A_Omega_C).reshape(-1),
        }
    else:
        cache = dict()
        cache_inds_batch = random_idx(tt, 200)
        cache["inds"] = cache_inds_batch

        func_vals_initial = index_fun(cache_inds_batch)
        if func_vals_initial is not None and func_vals_initial.size > 0:
            cache["func_vals"] = func_vals_initial.reshape(-1)
        else:
            cache["func_vals"] = np.array([])  # Start empty if no values returned

    errors = []
    stop_reason = "max_iterations"
    iterations_completed = 0

    for j in range(max_its):
        if verbose:
            print(f"Sweep {j}, direction {direction}. Algorithm: tt_cross_regular_v2_dmrg_step (Metric: {error_metric_name})")

        if direction == "LR":
            for i in range(len(tt) - 1):  # Supercores i = 0..d-2
                _sweep_step_dmrg(
                    i, "LR", tt, index_array, index_fun, Pmats, rank_kick=rank_kick, verbose=verbose, cache=cache
                )
            direction = "RL"
        else:
            for i in range(len(tt) - 2, -1, -1):  # Supercores i = d-2..0
                _sweep_step_dmrg(
                    i, "RL", tt, index_array, index_fun, Pmats, rank_kick=rank_kick, verbose=verbose, cache=cache
                )
            direction = "LR"
        iterations_completed = j + 1

        current_error_value = np.inf  # Stays inf when the cache is unusable
        if cache["inds"].shape[0] > 0 and cache["func_vals"].size > 0:
            # Flatten so an (N, 1) prediction cannot broadcast against the
            # 1-D targets into an (N, N) difference.
            y_pred = np.asarray(tt.gather(cache["inds"])).reshape(-1)
            y_true = np.asarray(cache["func_vals"]).reshape(-1)

            # NOTE(review): truncating to the shorter length hides any
            # index/value misalignment in the cache; confirm the sweep step
            # keeps "inds" and "func_vals" in sync.
            min_len = min(len(y_pred), len(y_true))
            if min_len > 0:
                y_pred_m, y_true_m = y_pred[:min_len], y_true[:min_len]
                if metric == 'rmse':
                    current_error_value = np.sqrt(np.mean((y_pred_m - y_true_m) ** 2))
                elif metric == 'relative_l2':
                    current_error_value = calculate_relative_l2_error(y_true_m, y_pred_m)
        errors.append(current_error_value)

        if verbose:
            print(f"Last {error_metric_name} (from cache): {current_error_value:.10e}")

        if tol_precision is not None and current_error_value < tol_precision:
            stop_reason = "precision"
            break

        if tol_flattening is not None and len(errors) > 3:
            max_prev_error = np.max(errors[-4:-1])
            if max_prev_error > 1e-12 and max_prev_error != np.inf:  # Avoid division by zero/inf
                change = (current_error_value - max_prev_error) / max_prev_error
                if change > -tol_flattening:
                    stop_reason = "flattening"
                    break
            elif current_error_value == max_prev_error:  # Tiny or stuck errors
                stop_reason = "flattening"
                break

    tt.orthogonalize()
    tt.errors = np.array(errors)
    return tt, stop_reason, iterations_completed, cache

# Module-level evaluation bookkeeping shared with optimize_tt_cross_rank_sweep.
# That function declares these names global, resets them before each rank and
# reads them afterwards (evaluation count and collected index/value pairs).
# NOTE(review): the code that increments/appends to them is not shown here;
# collecting_index_fun says it does no collection, so confirm which index
# function updates these.
total_vector_evaluations = 0
collected_indices = []
collected_evaluations = []

def _declared_keyword_params(func):
    """Return the set of parameter names ``func`` accepts by keyword.

    Returns None when ``func`` takes ``**kwargs`` or its signature cannot be
    introspected, meaning "any keyword may be accepted".
    """
    import inspect

    try:
        params = list(inspect.signature(func).parameters.values())
    except (TypeError, ValueError):
        return None
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return None
    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    return {p.name for p in params if p.kind in keyword_kinds}


def _format_index_for_print(idx):
    """Render a collected multi-index as comma-separated integers."""
    if isinstance(idx, (list, np.ndarray, tuple)) and np.ndim(idx) > 0:
        return ', '.join(str(int(j)) for j in np.ravel(idx))
    # Scalars and 0-d arrays are printed whole; str.join would otherwise split
    # their string form character by character.
    return str(idx)


def optimize_tt_cross_rank_sweep(
    d,
    MODAL_SIZE,
    sizeOmega,  # Evaluation budget for the automatic selection
    min_rank_to_try,
    max_rank_to_try,
    seed,
    tol_flattening,
    max_its,
    tol_precision,
    physical_point_index_fun,
    collecting_index_fun,
    tt_cross_algorithm_func,
    create_tt_random,
    create_tt_initial,  # Unused here; kept for signature compatibility
    get_tt_shape,
    TensorTrain,  # Unused here; kept for signature compatibility
    index_function_wrapper,
    Omega_C,  # Validation set (indices)
    A_Omega_C,  # Validation set (true values)
    # --- Helper functions supplied by the caller ---
    _calculate_tt_approximation_error_value,  # Computes the error metric
    process_and_verify,  # Builds a training set from collected indices/values
    # --- Additional parameters ---
    dmrg_rank_kick: int = 0,
    metric: str = 'relative_l2',
    manual_rank: int = None,
):
    """Sweep TT-cross over a range of target ranks and select one approximation.

    For every rank in ``[min_rank_to_try, max_rank_to_try]``:

    - A random initial TT is built with ``create_tt_random``.
    - It is refined with ``tt_cross_algorithm_func``.
    - The result is scored on the external validation set
      ``(Omega_C, A_Omega_C)`` with ``_calculate_tt_approximation_error_value``.

    The automatic choice is the run with the lowest validation error among
    those whose evaluation count does not exceed ``sizeOmega``. The count is
    read from the global ``total_vector_evaluations``. If ``manual_rank`` was
    swept, that run is chosen instead. The indices/evaluations collected for
    the chosen run are passed through ``process_and_verify`` to build a
    training set, and the error on that set is reported too.

    ``rank_kick`` is forwarded to ``tt_cross_algorithm_func`` only if its
    signature declares it. ``Omega_C``/``A_Omega_C`` are forwarded unless the
    signature is known not to accept them, so algorithms without an external
    validation set can be swept as well.

    Returns:
        tuple: ``(chosen_tt_approx, chosen_error_value_val, chosen_rank_found,
        chosen_evaluations_count, Omega_processed_train,
        A_Omega_processed_train)``. ``chosen_tt_approx`` is None and the error
        is ``np.inf`` when nothing was selected.

    Raises:
        ValueError: If ``metric`` is not supported.
    """
    global total_vector_evaluations, collected_indices, collected_evaluations

    # ANSI colour codes
    COLOR_BLUE = "\033[94m"
    COLOR_GREEN = "\033[92m"
    COLOR_RESET = "\033[0m"

    if metric not in ['rmse', 'relative_l2']:
        raise ValueError(f"Métrica no soportada: {metric}. Elija 'rmse' o 'relative_l2'.")
    error_metric_name = "RMSE" if metric == 'rmse' else "Error Relativo L2"

    NUM_INTERVALS = MODAL_SIZE - 1  # Forwarded to create_tt_random

    # Index-function chain handed to TT-cross.
    chebyshev_grid_index_fun_base = lambda indices_batch: physical_point_index_fun(indices_batch)
    wrapped_by_ttml_index_fun = index_function_wrapper(chebyshev_grid_index_fun_base)
    complex_index_fun_wrapped_physical = collecting_index_fun(wrapped_by_ttml_index_fun)

    # Optional keywords are forwarded only when the algorithm accepts them.
    # The real signature is checked instead of __code__.co_varnames, which also
    # lists local variables. Filtering Omega_C/A_Omega_C keeps algorithms
    # without those parameters from failing with TypeError on every rank.
    declared_params = _declared_keyword_params(tt_cross_algorithm_func)
    optional_algo_kwargs = {}
    if declared_params is not None and 'rank_kick' in declared_params:
        optional_algo_kwargs['rank_kick'] = dmrg_rank_kick
    for kwarg_name, kwarg_value in (('Omega_C', Omega_C), ('A_Omega_C', A_Omega_C)):
        if declared_params is None or kwarg_name in declared_params:
            optional_algo_kwargs[kwarg_name] = kwarg_value

    # Best automatic approximation found so far
    auto_best_tt_approx = None
    auto_best_error_value_val = np.inf  # Validation-set error
    auto_best_evaluations_list = []
    auto_best_indices_list = []
    auto_best_rank_found = None
    auto_best_evaluations_count = 0

    # Snapshot of the manually requested rank, if it gets swept
    manual_rank_data_tuple = None

    print(f"Iniciando barrido de rangos para TT-Cross (Métrica: {error_metric_name})")
    print(f"Validación externa: {Omega_C.shape[0] if hasattr(Omega_C, 'shape') else len(Omega_C)} puntos.")
    print("-" * 30)

    for current_rank_test in range(min_rank_to_try, max_rank_to_try + 1):
        print(f"\nProbando con rango TT objetivo: {current_rank_test}")
        desired_ranks = [1] + [current_rank_test] * (d - 1) + [1]

        # Fresh random initial TT for this rank
        tt_initial = create_tt_random(desired_ranks, NUM_INTERVALS, d, seed)

        print(f"  Rango TT inicial: {tt_initial.tt_rank}")
        initial_ttml_shape = get_tt_shape(tt_initial)
        print(f"  Forma TTML inicial: {initial_ttml_shape}")
        expected_shape = tuple([MODAL_SIZE] * d)
        if initial_ttml_shape != expected_shape:
            print(f"  Advertencia: La forma TTML inicial es {initial_ttml_shape}, pero se esperaba {expected_shape}.")

        # Reset the global evaluation bookkeeping for this rank.
        # NOTE(review): these are expected to be filled by the index function
        # chain; confirm which function updates them.
        total_vector_evaluations = 0
        collected_indices = []
        collected_evaluations = []

        start_time = time.time()

        tt_approx = None
        stop_reason = "Algoritmo no completado como se esperaba"
        iterations_completed = 0

        try:
            tt_approx, stop_reason, iterations_completed, _ = tt_cross_algorithm_func(
                tt=tt_initial,
                index_fun=complex_index_fun_wrapped_physical,
                tol_flattening=tol_flattening,
                tol_precision=tol_precision,
                metric=metric,  # Metric for the algorithm's internal stopping rule
                max_its=max_its,
                verbose=False,
                inplace=True,
                **optional_algo_kwargs
            )
        except Exception as e:
            # Best effort: report the failure and continue with the next rank.
            print(f"  Error durante tt_cross_algorithm_func para rango {current_rank_test}: {e}")

        end_time = time.time()
        duration = end_time - start_time
        evaluations_during_run = total_vector_evaluations
        collected_count = len(collected_evaluations)

        # Recompute the error on the validation set with the main metric.
        current_run_error_value_validation = np.inf
        if tt_approx is not None:
            try:
                current_run_error_value_validation = _calculate_tt_approximation_error_value(
                    tt_approx_model=tt_approx,
                    test_indices_raw=Omega_C,
                    test_true_values_raw=A_Omega_C,
                    metric=metric
                )
            except Exception as e:
                print(f"  Advertencia: Error al recalcular la métrica de validación final para el rango {current_rank_test}: {e}")
        else:
            print(f"  Advertencia: tt_approx es None para el rango {current_rank_test}. No se puede calcular el error de validación.")

        final_rank_of_approx = tt_approx.tt_rank if tt_approx is not None else "N/A (tt_approx is None)"
        print(f"  Rango TT final de la aproximación: {final_rank_of_approx}")
        print(f"  Evaluaciones totales de vectores (llamadas a physical_point_index_fun): {evaluations_during_run}")
        print(f"  Pares (índice, valor) recolectados individualmente: {collected_count}")
        print(f"  Iteraciones completadas: {iterations_completed}")
        print(f"  Razón de parada del algoritmo TT-Cross: {stop_reason}")
        print(f"  Tiempo de ejecución para este rango: {duration:.2f} segundos")
        print(f"  {COLOR_GREEN}{error_metric_name} (en Conjunto de Validación Omega_C, recalculado): {current_run_error_value_validation:.10e}{COLOR_RESET}")

        # Snapshot the manually requested rank
        if manual_rank is not None and current_rank_test == manual_rank:
            print(f"  Guardando datos para el rango especificado manualmente: {current_rank_test}.")
            manual_rank_data_tuple = (
                tt_approx.copy() if tt_approx is not None else None,
                current_run_error_value_validation,
                collected_evaluations.copy(),
                collected_indices.copy(),
                current_rank_test,
                evaluations_during_run
            )

        # Update the best automatic approximation: valid result, within the
        # evaluation budget, and best validation error so far.
        if tt_approx is not None and \
           evaluations_during_run <= sizeOmega and \
           current_run_error_value_validation < auto_best_error_value_val:

            auto_best_tt_approx = tt_approx.copy()
            auto_best_error_value_val = current_run_error_value_validation
            auto_best_evaluations_list = collected_evaluations.copy()
            auto_best_indices_list = collected_indices.copy()
            auto_best_rank_found = current_rank_test
            auto_best_evaluations_count = evaluations_during_run
            print(f"  ¡Nueva mejor aproximación automática encontrada!")
            print(f"    Rango: {current_rank_test}, {error_metric_name} (Validación): {auto_best_error_value_val:.10e}, Evaluaciones: {auto_best_evaluations_count}")

    # --- Final selection ---
    print(f"\n{'-'*30}\nSelección Final de la Aproximación TT")

    chosen_tt_approx = None
    chosen_error_value_val = np.inf  # Validation error of the chosen TT
    chosen_evaluations_list = []
    chosen_indices_list = []
    chosen_rank_found = None
    chosen_evaluations_count = 0
    selection_details_message = ""

    if manual_rank is not None:
        if manual_rank_data_tuple is not None:
            print(f"Priorizando resultados del rango especificado manualmente: {manual_rank}.")
            (chosen_tt_approx, chosen_error_value_val, chosen_evaluations_list,
             chosen_indices_list, chosen_rank_found, chosen_evaluations_count) = manual_rank_data_tuple
            selection_details_message = f"Rango seleccionado manualmente: {chosen_rank_found}"
            if chosen_tt_approx is None:
                selection_details_message += " (resultó en ninguna aproximación válida)"
            # Exceeding sizeOmega is only a warning for the manual rank.
            elif chosen_evaluations_count > sizeOmega:
                print(f"  Advertencia: El rango manual {chosen_rank_found} usó {chosen_evaluations_count} evaluaciones, excediendo sizeOmega ({sizeOmega}).")
                selection_details_message += f" (Nota: usó {chosen_evaluations_count} evaluaciones, límite sizeOmega={sizeOmega})"
        else:
            print(f"Advertencia: Se especificó el rango manual {manual_rank} pero no se procesó o no está en el rango [{min_rank_to_try}, {max_rank_to_try}].")
            if auto_best_tt_approx is not None:
                print("  Recurriendo a la mejor aproximación encontrada automáticamente.")
                chosen_tt_approx = auto_best_tt_approx
                chosen_error_value_val = auto_best_error_value_val
                chosen_evaluations_list = auto_best_evaluations_list
                chosen_indices_list = auto_best_indices_list
                chosen_rank_found = auto_best_rank_found
                chosen_evaluations_count = auto_best_evaluations_count
                selection_details_message = f"Rango seleccionado automáticamente: {chosen_rank_found} (rango manual {manual_rank} no procesado)"
            else:
                selection_details_message = f"Rango manual {manual_rank} no procesado, y no se encontró ninguna aproximación automática."
    else:
        if auto_best_tt_approx is not None:
            print("Usando la mejor aproximación encontrada automáticamente (no se especificó rango manual).")
            chosen_tt_approx = auto_best_tt_approx
            chosen_error_value_val = auto_best_error_value_val
            chosen_evaluations_list = auto_best_evaluations_list
            chosen_indices_list = auto_best_indices_list
            chosen_rank_found = auto_best_rank_found
            chosen_evaluations_count = auto_best_evaluations_count
            selection_details_message = f"Rango seleccionado automáticamente: {chosen_rank_found}"
        else:
            selection_details_message = "No se especificó rango manual, y no se encontró ninguna aproximación automática."

    # --- Build the training set from the chosen run's collected pairs ---
    Omega_processed_train = np.array([])
    A_Omega_processed_train = np.array([])
    if chosen_tt_approx is not None and chosen_indices_list and chosen_evaluations_list:
        print(f"\nProcesando los {len(chosen_indices_list)} pares (índice, valor) recolectados para el TT elegido...")
        try:
            processed_inds, processed_vals = process_and_verify(
                chosen_indices_list,
                chosen_evaluations_list,
                MODAL_SIZE  # Available to the verification step
            )
            # Coerce to arrays so the .size/.shape checks below also work
            # when process_and_verify returns plain lists.
            Omega_processed_train = np.asarray(processed_inds)
            A_Omega_processed_train = np.asarray(processed_vals)
            print(f"  Procesamiento completado. Obtenido conjunto de entrenamiento con {Omega_processed_train.shape[0] if hasattr(Omega_processed_train, 'shape') else len(Omega_processed_train)} puntos.")
        except Exception as e:
            print(f"  Error durante process_and_verify: {e}. Omega_processed_train y A_Omega_processed_train permanecerán vacíos.")
    elif chosen_tt_approx is not None:
        print("\nAdvertencia: El TT elegido no tiene índices/evaluaciones recolectados para procesar (listas vacías).")

    # --- Summary of the chosen approximation ---
    print(f"\n{'-'*30}\nResumen de la Aproximación TT Elegida")
    if chosen_tt_approx is not None:

        print(f"Detalles de selección: {selection_details_message}")
        print(f"  Rango TT objetivo del barrido: {chosen_rank_found}")
        print(f"  Rangos TT reales de los cores: {chosen_tt_approx.tt_rank}")

        # Error on the processed training set
        error_on_processed_training_set = np.inf
        if Omega_processed_train.size > 0 and A_Omega_processed_train.size > 0:
            try:
                error_on_processed_training_set = _calculate_tt_approximation_error_value(
                    tt_approx_model=chosen_tt_approx,
                    test_indices_raw=Omega_processed_train,
                    test_true_values_raw=A_Omega_processed_train,
                    metric=metric
                )
                print(f"  {COLOR_BLUE}{error_metric_name} (en Conjunto de Entrenamiento Procesado): {error_on_processed_training_set:.10e}{COLOR_RESET}")
            except Exception as e:
                print(f"  Advertencia: Error al calcular la métrica en el conjunto de entrenamiento procesado: {e}")
        else:
            print(f"  {COLOR_BLUE}{error_metric_name} (en Conjunto de Entrenamiento Procesado): N/A (conjunto vacío o no generado).{COLOR_RESET}")

        print(f"  {COLOR_GREEN}{error_metric_name} (en Conjunto de Validación Omega_C): {chosen_error_value_val:.10e}{COLOR_RESET}")

        print(f"  Evaluaciones totales de vectores (función original): {chosen_evaluations_count}")
        print(f"  Pares (índice, valor) recolectados originalmente: {len(chosen_evaluations_list)}")
        print(f"    (Número de índices recolectados: {len(chosen_indices_list)})")
        print(f"    (Tamaño del conjunto de entrenamiento procesado: {Omega_processed_train.shape[0] if hasattr(Omega_processed_train, 'shape') else 0})")

        min_print_count = min(5, len(chosen_indices_list))
        if min_print_count > 0:
            print(f"\n  Primeros {min_print_count} pares (índice, valor) recolectados originalmente (ejemplo):")
            for i in range(min_print_count):
                idx = chosen_indices_list[i]
                val = chosen_evaluations_list[i]
                print(f"    {i}: Índice={_format_index_for_print(idx)}, Evaluación={val:.8f}")
        else:
            print("\n  No se recolectaron pares (índice, valor) para la aproximación elegida.")

    else:
        print(f"\nNo se seleccionó ninguna aproximación TT adecuada. ({selection_details_message})")
        print(f"  (Rangos buscados de {min_rank_to_try} a {max_rank_to_try}. Límite de evaluaciones para selección auto: {sizeOmega})")

    print(f"{'-'*40}\nFin del barrido de rangos TT-Cross.\n{'-'*40}")

    return (chosen_tt_approx,
            chosen_error_value_val,  # Validation error of the chosen TT
            chosen_rank_found,
            chosen_evaluations_count,
            Omega_processed_train,  # Processed training set (indices)
            A_Omega_processed_train)  # Processed training set (values)

# Optimizadores

In [3]:
import numpy as np
import tensorflow as tf
import t3f

def optimize_tt_with_adam(X, A_Omega, Omega, A_Omega_C, Omega_C, d, number_nodes,
                          max_iters=7500, abs_loss_threshold=1e-6, improvement_threshold=1e-3,
                          patience=1000,  # Patience for the training-loss stopping criterion
                          learning_rate_initial=0.01, decay_steps=400, decay_rate=0.7,
                          # ReduceLROnPlateau parameters
                          reduce_lr_on_plateau=True, lr_patience=200, lr_factor=0.5,
                          lr_min_delta=1e-4, lr_min=1e-6, lr_monitor_interval=100,
                          reduce_lr_train_set=False,  # Monitor the training loss (instead of validation) to reduce the LR
                          verbose=False, loss='relative_l2'):
    """
    Complete a tensor in TT format by optimizing its TT cores with Adam.

    The cores of ``X`` are wrapped in a trainable t3f variable and updated so
    that the TT entries at ``Omega`` match ``A_Omega``. The learning rate
    follows either a hand-rolled ReduceLROnPlateau policy or a fixed Keras
    ExponentialDecay schedule.

    Args:
        X: Initial TT tensor used as the starting estimate.
        A_Omega: Observed values at the training positions ``Omega``.
        Omega: Indices of the observed (training) entries.
        A_Omega_C: Observed values at the validation positions ``Omega_C``.
        Omega_C: Indices of the validation entries.
        d: Number of tensor dimensions (only used in the progress report).
        number_nodes: Nodes per dimension (only used in the progress report).
        max_iters: Maximum number of optimization iterations.
        abs_loss_threshold: Stop once the training metric (RMSE if loss='mse')
            drops below this value.
        improvement_threshold: Stop when the relative improvement of the
            training metric over ``patience`` iterations is below this value.
        patience: Window (in iterations) for the improvement check.
        learning_rate_initial: Initial Adam learning rate.
        decay_steps: Decay period of the exponential schedule (when the
            plateau policy is disabled).
        decay_rate: Decay factor of the exponential schedule (when the plateau
            policy is disabled).
        reduce_lr_on_plateau: If True, use the plateau policy instead of the
            fixed exponential decay.
        lr_patience: Iterations without improvement of the monitored metric
            before the learning rate is reduced.
        lr_factor: Multiplicative reduction factor (new_lr = lr * lr_factor).
        lr_min_delta: Minimum decrease that counts as an improvement.
        lr_min: Lower bound for the learning rate.
        lr_monitor_interval: How often (in iterations) the plateau check and
            the verbose progress line run.
        reduce_lr_train_set (bool): If True, the plateau policy monitors the
            training metric; otherwise the validation metric.
        verbose: If True, print progress information.
        loss: Training loss, 'mse' or 'relative_l2' (case-insensitive).

    Returns:
        A tuple ``(estimated, loss_hist, val_loss_hist)``:
            - estimated: The optimized TT variable.
            - loss_hist: Training-metric history (RMSE if loss='mse'); entry 0
              is the initial state, entry ``i + 1`` the state after update ``i``.
            - val_loss_hist: Validation-metric history with the same layout.

    Raises:
        ValueError: If ``loss`` is neither 'mse' nor 'relative_l2'.
    """
    estimated = t3f.get_variable('estimated', initializer=X)

    # Learning-rate policy: a variable adjusted manually by the plateau logic
    # below, or a fixed exponential decay applied by Keras.
    lr_schedule_obj = None
    if reduce_lr_on_plateau:
        lr_variable = tf.Variable(learning_rate_initial, dtype=tf.float32)
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate_initial)
        optimizer.learning_rate = lr_variable  # The plateau logic assigns new values to this learning rate
    else:
        lr_schedule_obj = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=learning_rate_initial,
            decay_steps=decay_steps,
            decay_rate=decay_rate,
            staircase=False
        )
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule_obj)

    def calculate_loss(estimated_tensor, values, indices, metric_type='current_loss'):
        """
        Evaluate a loss/metric between the TT entries at ``indices`` and ``values``.

        metric_type: 'mse', 'rmse', 'relative_l2', or 'current_loss' (the
        training loss selected by ``loss``).
        """
        estimated_vals = t3f.gather_nd(estimated_tensor, indices)
        diff = estimated_vals - values

        if metric_type == 'mse':
            return tf.reduce_mean(diff ** 2)
        elif metric_type == 'rmse':
            return tf.sqrt(tf.reduce_mean(diff ** 2))
        elif metric_type == 'relative_l2':
            norm_values = tf.norm(values)
            if tf.abs(norm_values) < tf.keras.backend.epsilon():
                return tf.norm(diff)
            return tf.norm(diff) / norm_values
        elif metric_type == 'current_loss':
            if loss.lower() == 'mse':
                return tf.reduce_mean(diff ** 2)
            elif loss.lower() == 'relative_l2':
                norm_values = tf.norm(values)
                if tf.abs(norm_values) < tf.keras.backend.epsilon():
                    return tf.norm(diff)
                return tf.norm(diff) / norm_values
            else:
                raise ValueError(f"Tipo de pérdida no válido: '{loss}'. Debe ser 'MSE' o 'relative_L2'.")
        else:
            raise ValueError(f"Tipo de métrica no válido: '{metric_type}'.")

    # Metric recorded/displayed/used for stopping: RMSE when training on MSE,
    # otherwise the training loss itself.
    display_metric = 'rmse' if loss.lower() == 'mse' else 'current_loss'

    @tf.function
    def step_and_update():
        """Apply one Adam update and return the post-update (train, val) metrics."""
        with tf.GradientTape() as tape:
            loss_train_for_grad = calculate_loss(estimated, A_Omega, Omega, metric_type='current_loss')

        gradients = tape.gradient(loss_train_for_grad, estimated.tt_cores)
        # Python-level check, resolved at trace time.
        if any(g is None for g in gradients):
            tf.print("Warning: Found None gradients. Skipping gradient application.")
        else:
            optimizer.apply_gradients(zip(gradients, estimated.tt_cores))

        loss_train_display = calculate_loss(estimated, A_Omega, Omega, metric_type=display_metric)
        loss_val_display = calculate_loss(estimated, A_Omega_C, Omega_C, metric_type=display_metric)
        return loss_train_display, loss_val_display

    loss_hist = []
    val_loss_hist = []

    # ReduceLROnPlateau state
    lr_wait = 0
    monitor_set_name = "Train" if reduce_lr_train_set else "Val"
    loss_metric_name = "RMSE" if loss.lower() == 'mse' else loss.upper()

    sizeOmega = tf.shape(Omega)[0]
    sizeOmega_C = tf.shape(Omega_C)[0]

    tt_rank = t3f.tt_ranks(X)

    print(f"Starting Adam Tensor Completion for a target tensor.")
    print(f"Dimensions: {d}, Nodes per dimension: {number_nodes}, Total size: {number_nodes**d}")
    print(f"Initial ranks: {tt_rank}")
    print(f"Training points: {sizeOmega}, Validation points: {sizeOmega_C}")
    print(f"Adam LR Initial: {learning_rate_initial:.4f}, Max Iters: {max_iters}")
    if reduce_lr_on_plateau:
        monitor_set = "Train" if reduce_lr_train_set else "Validation"
        print(f"Using ReduceLROnPlateau: Monitor={monitor_set} {loss_metric_name}, Patience={lr_patience}, Factor={lr_factor:.2f}, MinDelta={lr_min_delta:.1e}, MinLR={lr_min:.1e}, MonitorInterval={lr_monitor_interval}")
    else:
        print(f"Using Exponential Decay: DecaySteps={decay_steps}, DecayRate={decay_rate}")

    print("-" * 30)

    # --- Metrics before the first update (history entry 0) ---
    initial_loss_train_v = calculate_loss(estimated, A_Omega, Omega, metric_type=display_metric).numpy()
    initial_loss_val_v = calculate_loss(estimated, A_Omega_C, Omega_C, metric_type=display_metric).numpy()

    loss_hist.append(initial_loss_train_v)
    val_loss_hist.append(initial_loss_val_v)

    if verbose:
        print(f"Estado Inicial (Iteración -1): Loss (Train - {loss_metric_name}) = {initial_loss_train_v:.6f}, "
              f"Loss (Val - {loss_metric_name}) = {initial_loss_val_v:.6f}, LR = {learning_rate_initial:.4f}")

    # The plateau baseline is the initial value of whichever metric is monitored.
    best_monitor_loss = initial_loss_train_v if reduce_lr_train_set else initial_loss_val_v

    for i in range(max_iters):
        try:
            loss_train_t, loss_val_t = step_and_update()
            # Store numpy scalars, like the initial history entries.
            loss_train_v = loss_train_t.numpy()
            loss_val_v = loss_val_t.numpy()
        except Exception as e:
            # Best effort: report and stop, keeping the histories gathered so far.
            print(f"Error en la iteración {i}: {e}")
            break

        loss_hist.append(loss_train_v)
        val_loss_hist.append(loss_val_v)

        # Metric watched by the plateau policy.
        monitor_loss = loss_train_v if reduce_lr_train_set else loss_val_v

        if reduce_lr_on_plateau:
            current_lr_value = optimizer.learning_rate.numpy()
        else:
            # Evaluate the schedule object directly: on current Keras versions
            # `optimizer.learning_rate` holds the current LR value, not the
            # callable schedule, so calling it raises TypeError.
            current_lr_value = float(lr_schedule_obj(i + 1))

        if verbose and (i + 1) % lr_monitor_interval == 0:  # Print at steps 100, 200, ...
            print(f"Iteración {i+1}: Loss (Train - {loss_metric_name}) = {loss_train_v:.6f}, "
                  f"Loss (Val - {loss_metric_name}) = {loss_val_v:.6f}, LR = {current_lr_value:.4f}")

        # --- ReduceLROnPlateau logic (checked every lr_monitor_interval steps) ---
        if reduce_lr_on_plateau and (i + 1) % lr_monitor_interval == 0 and i > 0:
            if monitor_loss < best_monitor_loss - lr_min_delta:
                best_monitor_loss = monitor_loss
                lr_wait = 0
                if verbose and (i + 1) % (lr_monitor_interval * 10) == 0:
                    print(f"Iteración {i+1}: Mejora en pérdida de {monitor_set_name} ({monitor_loss:.6f}). Reiniciando paciencia. Best_{monitor_set_name}_loss: {best_monitor_loss:.6f}")
            else:
                lr_wait += lr_monitor_interval
                if verbose and (i + 1) % (lr_monitor_interval * 10) == 0:
                    print(f"Iteración {i+1}: Sin mejora en pérdida de {monitor_set_name} ({monitor_loss:.6f}). Paciencia: {lr_wait}/{lr_patience}. Best_{monitor_set_name}_loss: {best_monitor_loss:.6f}")

                if lr_wait >= lr_patience:
                    old_lr = optimizer.learning_rate.numpy()
                    new_lr = max(old_lr * lr_factor, lr_min)
                    # Only reduce (and reset patience) if the LR can still go down.
                    if new_lr < old_lr - 1e-10:
                        optimizer.learning_rate.assign(new_lr)
                        if verbose:
                            print(f"Iteración {i+1}: Paciencia agotada ({lr_patience}). Reduciendo LR de {old_lr:.4f} a {new_lr:.4f}.")
                        lr_wait = 0

        # --- Stopping criteria on the training metric ---
        if i >= patience and (i % patience == 0 or i == max_iters - 1):
            if len(loss_hist) > patience:
                # loss_hist[0] is the initial state and loss_hist[i + 1] the
                # current one, so the value `patience` updates ago is at i + 1 - patience.
                loss_prev = loss_hist[i + 1 - patience]
                improvement = (loss_prev - loss_train_v) / (loss_prev + tf.keras.backend.epsilon())

                if improvement < improvement_threshold:
                    if verbose:
                        print(f"Parando en iteración {i+1}: Mejora (entrenamiento {loss_metric_name}: {improvement:.6f}) < {improvement_threshold:.6f} durante {patience} iteraciones.")
                    break

        if loss_train_v < abs_loss_threshold:
            if verbose:
                print(f"Parando en iteración {i+1}: Loss (entrenamiento {loss_metric_name}) = {loss_train_v:.6e} por debajo del umbral absoluto ({abs_loss_threshold}).")
            break

    # --- Final summary ---
    if verbose:
        print("\n--- Resumen Final ---")
        if loss_hist:
            print(f"Pérdida final de entrenamiento ({loss_metric_name}): {loss_hist[-1]:.6f}")
            print(f"Pérdida final de validación ({loss_metric_name}): {val_loss_hist[-1]:.6f}")

            if loss.lower() == 'relative_l2':
                final_rmse_train = calculate_loss(estimated, A_Omega, Omega, metric_type='rmse').numpy()
                final_rmse_val = calculate_loss(estimated, A_Omega_C, Omega_C, metric_type='rmse').numpy()
                print(f"Pérdida final de entrenamiento (RMSE): {final_rmse_train:.6f}")
                print(f"Pérdida final de validación (RMSE): {final_rmse_val:.6f}")
        else:
            print("No se registraron pérdidas (posiblemente debido a un error temprano).")

    return estimated, loss_hist, val_loss_hist

import tensorflow as tf
import tensorflow_probability as tfp
import t3f
import numpy as np

def optimize_tt_with_lbfgs(
        X, A_Omega, Omega, A_Omega_C, Omega_C,
        d_param, number_nodes,
        max_iters=7500, verbose=False, loss='relative_l2',
        print_every=25):
    """
    Fit the TT cores of ``X`` to the observed entries with TFP's L-BFGS.

    All TT cores are flattened into a single float64 parameter vector that is
    optimized by ``tfp.optimizer.lbfgs_minimize``. The objective is the loss on
    the training entries; the validation loss is only recorded for monitoring.

    Args:
        X: Initial TT tensor.
        A_Omega, Omega: Training values and their indices.
        A_Omega_C, Omega_C: Validation values and their indices.
        d_param, number_nodes: Accepted for signature parity with the other
            optimizers; not used here.
        max_iters: Maximum number of L-BFGS iterations.
        verbose: If True, print progress and a final summary.
        loss: 'relative_l2' (||diff|| / (||values|| + 1e-12)); any other name
            (compared case-insensitively) falls back to MSE.
        print_every: When verbose, print every ``print_every`` objective evaluations.

    Returns:
        A tuple ``(estimated, train_losses, val_losses, converged)``:
            - estimated: TensorTrain rebuilt from the final parameters (float64 cores).
            - train_losses, val_losses: numpy arrays with one entry per
              objective evaluation (L-BFGS may evaluate the objective more than
              once per iteration, e.g. during its line search).
            - converged: whether L-BFGS reported convergence.
    """
    # ---------- Initialization ----------
    estimated = t3f.get_variable('estimated', initializer=X)
    tt_cores_shapes = [core.shape for core in estimated.tt_cores]
    tt_ranks = t3f.tt_ranks(X).numpy()
    tensor_shape = estimated.get_shape()
    # L-BFGS works on one flat float64 vector that holds every TT core.
    initial_params = tf.concat(
        [tf.cast(tf.reshape(c, [-1]), tf.float64) for c in estimated.tt_cores],
        axis=0)

    # Loss values, appended once per objective evaluation.
    train_losses_history = []
    val_losses_history = []

    if verbose:
        print("\n========== Comienzo optimización L-BFGS ==========")
        print("Formas núcleos TT :", [s.as_list() for s in tt_cores_shapes])
        print("Rangos TT         :", tt_ranks)
        print("Forma del tensor  :", tensor_shape.as_list())
        print(f"Parámetros totales: {initial_params.shape[0]}")
        print("---------------------------------------------------")
        print(f"{'Iter':>7} | {'Loss train':>14} | {'Loss val':>14}")
        print("---------------------------------------------------")

    # --------- Internal helpers ---------
    def reconstruct_tt(params):
        """Rebuild a TensorTrain from the flat parameter vector."""
        offset, new_cores = 0, []
        for shape in tt_cores_shapes:
            size = np.prod(shape.as_list())
            core = tf.reshape(params[offset:offset+size], shape)
            new_cores.append(core)
            offset += size
        return t3f.TensorTrain(new_cores, shape=tensor_shape, tt_ranks=tt_ranks)

    def compute_loss(tt_tensor, indices, values):
        """Loss between tt_tensor[indices] and values ('relative_l2' or MSE)."""
        estimated_vals = t3f.gather_nd(tt_tensor, indices)
        # The cores are float64 (see initial_params); cast the observations so
        # that e.g. float32 data does not fail with a dtype mismatch.
        values = tf.cast(values, estimated_vals.dtype)
        dif = estimated_vals - values

        if loss.lower() == 'relative_l2':
            eps = tf.constant(1e-12, dtype=values.dtype)
            return tf.norm(dif) / (tf.norm(values) + eps)
        # -------- 'mse' by default --------
        return tf.reduce_mean(tf.square(dif))

    # --------------- value + grad ---------------
    iter_counter = 0

    def loss_and_grad(params):
        """Objective for L-BFGS: training loss and gradient; also logs both losses."""
        nonlocal iter_counter  # the history lists are only mutated, so they need no nonlocal

        with tf.GradientTape() as tape:
            tape.watch(params)
            temp_est = reconstruct_tt(params)
            loss_value = compute_loss(temp_est, Omega, A_Omega)  # training loss
        grads = tape.gradient(loss_value, params)

        # ---------- Progress logs and history ----------
        iter_counter += 1

        val_loss_now = compute_loss(temp_est, Omega_C, A_Omega_C).numpy()
        train_losses_history.append(loss_value.numpy())
        val_losses_history.append(val_loss_now)

        if verbose and (iter_counter == 1 or iter_counter % print_every == 0):
            print(f"{iter_counter:7d} | {loss_value.numpy():14.6e} | {val_loss_now:14.6e}")

        return loss_value, grads

    # --------------- L-BFGS ---------------
    result = tfp.optimizer.lbfgs_minimize(
        value_and_gradients_function=loss_and_grad,
        initial_position=initial_params,
        max_iterations=max_iters)

    final_params = result.position
    estimated = reconstruct_tt(final_params)

    # --------------- Final metrics ---------------
    if verbose:
        print("\n--- Resumen L-BFGS ---")
        print("Iteraciones realizadas :", result.num_iterations.numpy())
        print(f"Loss train (final)     : {train_losses_history[-1]:.6f}")  # last recorded evaluation
        print(f"Loss val (final)       : {val_losses_history[-1]:.6f}")    # last recorded evaluation
        print("Convergió?             :", bool(result.converged.numpy()))
        print("---------------------------------------------")

    return estimated, np.array(train_losses_history), np.array(val_losses_history), \
           bool(result.converged.numpy())

import tensorflow as tf
import t3f
import functools
import numpy as np
import matplotlib.pyplot as plt
from ttml.tensor_train import TensorTrain

def linearized_search(A_Omega, X_k, eta_k, Omega):
    """
    Exact step size for the linearized least-squares line search.

    Minimizes ||P(X_k + alpha * eta_k) - A_Omega||^2 over alpha, where P samples
    a TT tensor at the indices ``Omega``:

        alpha = <P eta_k, A_Omega - P X_k> / <P eta_k, P eta_k>

    Args:
        A_Omega: Observed values at ``Omega`` (converted to float64).
        X_k: Current TT iterate.
        eta_k: TT search direction.
        Omega: Indices of the observed entries.

    Returns:
        Scalar tensor with the step size. If the direction vanishes on
        ``Omega`` (zero denominator), 0 is returned instead of NaN/Inf so the
        iterate is left unchanged rather than corrupted.
    """
    X_k_proyectado = t3f.gather_nd(X_k, Omega)
    eta_k_proyectado = t3f.gather_nd(eta_k, Omega)

    resta = tf.constant(A_Omega, dtype=tf.float64) - X_k_proyectado

    num = tf.tensordot(eta_k_proyectado, resta, axes=1)
    den = tf.tensordot(eta_k_proyectado, eta_k_proyectado, axes=1)

    return tf.math.divide_no_nan(num, den)

def increase_tt_rank_mu(tt_tensor, mu, noise_magnitude=1e-8):
    """
    Increase by one the TT rank between cores mu and mu+1 (1-based), following algorithm (4.7).

    The tensor is left-to-right orthogonalized. Core mu then gets one extra
    random slice along its right rank axis, and core mu+1 gets one extra random
    slice along its left rank axis.

    Args:
        tt_tensor (t3f.TensorTrain): Tensor in TT format.
        mu (int): 1-based position whose right rank is increased; valid values
            are 1, ..., d-1.
        noise_magnitude (float, optional): Standard deviation of the random
            entries of the two new slices. Default 1e-8.

    Returns:
        t3f.TensorTrain: Tensor with the rank between cores mu and mu+1 increased by one.

    Raises:
        ValueError: If ``mu`` is outside 1, ..., d-1.
    """
    num_dims = len(tt_tensor.tt_cores)
    if not 1 <= mu <= num_dims - 1:
        raise ValueError(f"mu must satisfy 1 <= mu <= {num_dims - 1}, got {mu}.")

    # Step 1: left-to-right orthogonalization
    tt_ortho = t3f.orthogonalize_tt_cores(tt_tensor, left_to_right=True)

    # Step 2: cores at (1-based) positions mu and mu+1
    tt_cores = tuple(tt_ortho.tt_cores)
    U_L = tt_cores[mu - 1]
    U_R = tt_cores[mu]

    # Step 3: sizes taken from the orthogonalized cores themselves, since
    # orthogonalization (QR) may change ranks relative to the input tensor.
    r_left, n_left, _ = U_L.shape
    _, n_right, r_right = U_R.shape

    # Step 4: random slices in the cores' dtype (tf.concat requires equal dtypes)
    R_mu = tf.random.normal(shape=[r_left, n_left, 1], mean=0.0,
                            stddev=noise_magnitude, dtype=U_L.dtype)
    R_mu_plus_1 = tf.random.normal(shape=[1, n_right, r_right], mean=0.0,
                                   stddev=noise_magnitude, dtype=U_R.dtype)

    # Step 5: append them along the shared rank axis
    U_L_new = tf.concat([U_L, R_mu], axis=2)        # right rank + 1
    U_R_new = tf.concat([U_R, R_mu_plus_1], axis=0)  # left rank + 1

    # Step 6: rebuild the TT tensor with the modified cores
    new_cores = tt_cores[:mu - 1] + (U_L_new, U_R_new) + tt_cores[mu + 1:]
    return t3f.TensorTrain(new_cores)

def truncate(tensor, target_ranks):
    """Round a TT tensor down to at most ``target_ranks`` via ``t3f.round``.

    No ``epsilon`` is passed, so t3f's default is used. An earlier version used
    ``epsilon=1e-6``; whether the default suits this use case was left to confirm.
    """
    rounded = t3f.round(tensor, max_tt_rank=target_ranks)
    return rounded

# Gradiente Riemanniano automático
def calcular_gradiente_riemanniano_tf(X, A_Omega, Omega):
    """Riemannian gradient of ``funcion_objetivo`` at ``X``, as a t3f TensorTrain.

    ``X`` is left-to-right orthogonalized first. ``t3f.gradients`` is called with
    runtime_check disabled.
    """
    def objective_at(tt):
        # Bind the observed data so only the TT tensor is differentiated.
        return funcion_objetivo(tt, A_Omega=A_Omega, Omega=Omega)

    return t3f.gradients(objective_at,
                         t3f.orthogonalize_tt_cores(X, left_to_right=True),
                         runtime_check=False)

def funcion_objetivo(X, A_Omega, Omega):
    """Mean least-squares misfit on the observed entries.

    Returns 0.5 * ||X[Omega] - A_Omega||^2 / |Omega| as a float64 scalar,
    where X[Omega] is gathered with ``t3f.gather_nd``.
    """
    residual = t3f.gather_nd(X, Omega) - tf.constant(A_Omega, dtype=tf.float64)
    observed = tf.constant(A_Omega, dtype=tf.float64)
    num_observed = tf.cast(tf.size(observed), tf.float64)

    squared_norm = tf.tensordot(residual, residual, axes=1)
    return 0.5 * squared_norm / num_observed

def riemannian_tensor_completion(X_0, A_Omega, Omega, A_Omega_C, Omega_C, d, number_nodes,
                                 max_iters=100, tol_riemannian_tensor_completion=1e-4,
                                 loss='relative_l2', verbose=False, print_interval=100,
                                 early_stopping=False):
    """
    Tensor completion with a Riemannian nonlinear conjugate-gradient method.

    Each step proceeds as follows:
      1. Compute the Riemannian gradient of ``funcion_objetivo``.
      2. Build a Fletcher-Reeves conjugate direction. The previous direction is
         transported by projecting it onto the tangent space at the current
         iterate (``t3f.project``).
      3. Choose the step with ``linearized_search``.
      4. Truncate back to the TT ranks of ``X_0`` (``truncate``).

    Args:
        X_0: Initial TT estimate; its ranks are kept throughout.
        A_Omega: Observed values (training).
        Omega: Indices of the observed entries (training).
        A_Omega_C: Observed values (validation).
        Omega_C: Indices of the observed entries (validation).
        d: Number of tensor dimensions (only used in the report).
        number_nodes: Nodes per dimension (only used in the report).
        max_iters: Conjugate-gradient iterations performed after the initial
            steepest-descent step (so up to ``max_iters + 1`` updates in total).
        tol_riemannian_tensor_completion: Tolerance on the relative change
            ||X_new - X_k|| / ||X_k||. Only used when ``early_stopping`` is True.
        loss: 'mse' or 'relative_l2' (case-insensitive).
        verbose: If True, print progress information.
        print_interval: Print the status every this many iterations.
        early_stopping: If True, stop once the relative change drops below the
            tolerance. Defaults to False (the check was previously disabled).

    Returns:
        A tuple ``(X_k, loss_hist, val_loss_hist)``: the final TT estimate and
        the train/validation histories (RMSE if loss='mse'). Entry 0 is the
        initial state and entry j is the state after j updates.

    Raises:
        ValueError: If ``loss`` is neither 'mse' nor 'relative_l2'.
    """
    def calculate_loss(estimated_tensor, values, indices, metric_type='current_loss'):
        """
        Evaluate a loss/metric between the TT entries at ``indices`` and ``values``.
        """
        estimated_vals = t3f.gather_nd(estimated_tensor, indices)
        diff = estimated_vals - values

        if metric_type == 'mse':
            return tf.reduce_mean(diff ** 2)
        elif metric_type == 'rmse':
            return tf.sqrt(tf.reduce_mean(diff ** 2))
        elif metric_type == 'relative_l2':
            norm_values = tf.norm(values)
            if tf.abs(norm_values) < tf.keras.backend.epsilon():
                return tf.norm(diff)
            return tf.norm(diff) / norm_values
        elif metric_type == 'current_loss':
            if loss.lower() == 'mse':
                return tf.reduce_mean(diff ** 2)
            elif loss.lower() == 'relative_l2':
                norm_values = tf.norm(values)
                if tf.abs(norm_values) < tf.keras.backend.epsilon():
                    return tf.norm(diff)
                return tf.norm(diff) / norm_values
            else:
                raise ValueError(f"Tipo de pérdida no válido: '{loss}'.")
        else:
            raise ValueError(f"Tipo de métrica no válido: '{metric_type}'.")

    # --- Initialization ---
    X_k = X_0
    target_ranks = t3f.tt_ranks(X_0).numpy()
    loss_hist = []
    val_loss_hist = []
    loss_metric_name = "RMSE" if loss.lower() == 'mse' else loss.upper()
    # Metric stored in the histories: RMSE when optimizing MSE, otherwise the loss itself.
    display_metric = 'rmse' if loss.lower() == 'mse' else 'current_loss'

    def record_losses(tensor):
        """Append the train/validation metrics of ``tensor`` to the histories and return them."""
        train_value = calculate_loss(tensor, A_Omega, Omega, metric_type=display_metric).numpy()
        val_value = calculate_loss(tensor, A_Omega_C, Omega_C, metric_type=display_metric).numpy()
        loss_hist.append(train_value)
        val_loss_hist.append(val_value)
        return train_value, val_value

    # --- Initial report ---
    sizeOmega = tf.shape(Omega)[0]
    sizeOmega_C = tf.shape(Omega_C)[0]
    print("Starting Riemannian Tensor Completion for a target tensor.")
    print(f"Dimensions: {d}, Nodes per dimension: {number_nodes}, Total size: {number_nodes**d}")
    print(f"Initial ranks: {target_ranks}")
    print(f"Training points: {sizeOmega}, Validation points: {sizeOmega_C}")
    print(f"Max Iters: {max_iters}, Tolerance: {tol_riemannian_tensor_completion:.1e}")
    print("-" * 30)

    # --- Initial losses (history entry 0) ---
    initial_loss_train, initial_loss_val = record_losses(X_k)

    if verbose:
        print(f"Estado Inicial (Iteración 0): Loss (Train - {loss_metric_name}) = {initial_loss_train:.6f}, "
              f"Loss (Val - {loss_metric_name}) = {initial_loss_val:.6f}")

    # --- First step: steepest descent (no previous direction yet) ---
    xi_k = calcular_gradiente_riemanniano_tf(X_k, A_Omega, Omega)
    eta_k = -xi_k
    alpha_k = linearized_search(A_Omega, X_k, eta_k, Omega)
    X_k = truncate(X_k + alpha_k * eta_k, target_ranks)
    # Record this update too, so history entry j is the state after j updates.
    record_losses(X_k)

    ip_xi_xi_old = t3f.frobenius_norm_squared(xi_k)
    eta_k_anterior = eta_k

    # --- Main optimization loop ---
    for k in range(1, max_iters + 1):
        # Riemannian gradient and Fletcher-Reeves conjugate direction
        xi_k = calcular_gradiente_riemanniano_tf(X_k, A_Omega, Omega)
        eta_transported = t3f.project(eta_k_anterior, X_k)  # transport by projection

        ip_xi_xi = t3f.frobenius_norm_squared(xi_k)
        beta_k = ip_xi_xi / ip_xi_xi_old if ip_xi_xi_old != 0 else 0
        eta_k = -xi_k + beta_k * eta_transported

        # Step size and truncation back to the target ranks
        alpha_k = linearized_search(A_Omega, X_k, eta_k, Omega)
        X_k_new = truncate(X_k + alpha_k * eta_k, target_ranks)

        loss_train_v, loss_val_v = record_losses(X_k_new)

        if verbose and k % print_interval == 0:
            print(f"Iteración {k}: Loss (Train - {loss_metric_name}) = {loss_train_v:.6f}, "
                  f"Loss (Val - {loss_metric_name}) = {loss_val_v:.6f}")

        # Optional stopping rule on the relative change of the iterate; only
        # computed when enabled because the TT difference and norm are costly.
        if early_stopping:
            diff_norm = t3f.frobenius_norm(X_k_new - X_k)
            x_norm = t3f.frobenius_norm(X_k)
            relative_diff = diff_norm / (x_norm + tf.keras.backend.epsilon())
            if relative_diff < tol_riemannian_tensor_completion:
                if verbose:
                    print(f"\nParando en iteración {k}: La mejora relativa ({relative_diff:.3e}) es menor que la tolerancia ({tol_riemannian_tensor_completion:.1e}).")
                X_k = X_k_new
                break

        # State for the next iteration
        X_k = X_k_new
        eta_k_anterior = eta_k
        ip_xi_xi_old = ip_xi_xi

        if k == max_iters:
            if verbose:
                print(f"\nParando en iteración {k}: Se alcanzó el número máximo de iteraciones.")

    # --- Final summary ---
    if verbose:
        print("\n--- Resumen Final ---")
        if loss_hist:
            final_train_loss = loss_hist[-1]
            final_val_loss = val_loss_hist[-1]
            print(f"Pérdida final de entrenamiento ({loss_metric_name}): {final_train_loss:.6f}")
            print(f"Pérdida final de validación ({loss_metric_name}): {final_val_loss:.6f}")

            # Also report RMSE when the main metric is not RMSE
            if loss.lower() != 'mse':
                final_rmse_train = calculate_loss(X_k, A_Omega, Omega, metric_type='rmse').numpy()
                final_rmse_val = calculate_loss(X_k, A_Omega_C, Omega_C, metric_type='rmse').numpy()
                print(f"Pérdida final de entrenamiento (RMSE): {final_rmse_train:.6f}")
                print(f"Pérdida final de validación (RMSE): {final_rmse_val:.6f}")
        else:
            print("No se registraron pérdidas.")

    return X_k, loss_hist, val_loss_hist

# Función directora

In [4]:
import numpy as np
import t3f # Assuming t3f is installed and provides TensorTrain and other functionalities
import matplotlib.pyplot as plt # For plotting

# =============================================================================
# GLOBAL VARIABLES (for custom_functions)
# =============================================================================
# Running count of evaluated points: the custom functions (f1, f2, and
# presumably f3-f4) add the number of rows they evaluate.
total_vector_evaluations: int = 0
# Input dimension the custom functions check their input's column count against.
d: int = 10
# Function chosen by run_tensor_completion_pipeline (assigned there); declared
# at module level so other parts of the pipeline can reference it.
custom_function = None

# =============================================================================
# CUSTOM FUNCTIONS (f1 to f4)
# =============================================================================

def f1(x):
    """
    Objective function f1: a sum of (mostly) one-variable terms.

    Accepts a numpy array of shape (num_samples, d) or a single point of
    shape (d,), which is promoted to (1, d). Each call adds the number of
    evaluated rows to the module-level counter ``total_vector_evaluations``.
    Writing x_k for column k of the input:

        f1(x) = x_0^2 + exp(x_1) + log(x_2 + 1) + 1 / (x_3 + 1) + x_4^3
                + sqrt(x_5) + x_6 * x_7 + x_8^4 - exp(-x_9)

    Floating-point warnings (divide, invalid, overflow, underflow) are
    suppressed, so out-of-domain inputs produce nan/inf entries silently.

    Raises:
        ValueError: if the number of input columns differs from the global ``d``.

    Returns:
        numpy.ndarray of shape (num_samples,); a 1-D input yields shape (1,).
    """
    global total_vector_evaluations
    global d

    points = x if len(x.shape) != 1 else x.reshape(1, -1)

    n_rows, n_cols = points.shape
    if n_cols != d:
        raise ValueError(f"Dimension mismatch in f1: Expected {d}, got {n_cols}")

    # The counter is bumped before evaluation, exactly once per input row.
    total_vector_evaluations += n_rows

    with np.errstate(divide='ignore', invalid='ignore', over='ignore', under='ignore'):
        cols = [points[:, k] for k in range(10)]
        pieces = (
            cols[0] ** 2,
            np.exp(cols[1]),
            np.log(cols[2] + 1),
            1 / (cols[3] + 1),
            cols[4] ** 3,
            np.sqrt(cols[5]),
            cols[6] * cols[7],
            cols[8] ** 4,
            -np.exp(-cols[9]),
        )
        # Accumulate left to right so floating-point rounding matches a chained "+".
        values = pieces[0]
        for piece in pieces[1:]:
            values = values + piece
    return values

def f2(x):
    """
    Objective function f2: coupled polynomial, exponential, logarithmic and rational terms.

    Accepts a numpy array of shape (num_samples, d) or a single point of
    shape (d,), which is promoted to (1, d). Each call adds the number of
    evaluated rows to the module-level counter ``total_vector_evaluations``.
    Writing x_k for column k of the input:

        f2(x) = x_0^2 * x_9^2 * x_8^3
                + exp(-x_1^2 + x_5^4 - x_7 * x_6)
                + log(1 + x_3 + x_2 * x_7)
                + (x_0^2 + x_1 + x_4) / (1 + x_5^3 + x_2^2 + x_8^4)

    Floating-point warnings are suppressed inside the computation.

    Raises:
        ValueError: if the number of input columns differs from the global ``d``.

    Returns:
        numpy.ndarray of shape (num_samples,); a 1-D input yields shape (1,).
    """
    global total_vector_evaluations
    global d

    batch = x.reshape(1, -1) if len(x.shape) == 1 else x

    count, width = batch.shape
    if width != d:
        raise ValueError(f"Dimension mismatch in f2: Expected {d}, got {width}")

    total_vector_evaluations += count

    with np.errstate(divide='ignore', invalid='ignore', over='ignore', under='ignore'):
        def col(k):
            # Column k of the batch, i.e. coordinate x_k for every sample.
            return batch[:, k]

        polynomial = (col(0) ** 2) * (col(9) ** 2) * (col(8) ** 3)
        exponential = np.exp(-col(1) ** 2 + col(5) ** 4 - col(7) * col(6))
        logarithmic = np.log(1 + col(3) + col(2) * col(7))
        rational = (col(0) ** 2 + col(1) + col(4)) / (1 + col(5) ** 3 + col(2) ** 2 + col(8) ** 4)
        values = polynomial + exponential + logarithmic + rational
    return values

def f3(x):
    """
    Objective function f3: f2-style terms plus an oscillation in x_0 and a quadratic well in x_1.

    Accepts a numpy array of shape (num_samples, d) or a single point of
    shape (d,), which is promoted to (1, d). Each call adds the number of
    evaluated rows to the module-level counter ``total_vector_evaluations``.
    Writing x_k for column k of the input:

        f3(x) = (x_0^2 + 5 sin(4 pi x_0)) * x_9^2 * x_8^3
                + exp(-x_1^2 + x_5^4 - x_7 * x_6)
                + log(1 + x_3 + x_2 * x_7)
                + (x_0^2 + x_1 + x_4) / (1 + x_5^3 + x_2^2 + x_8^4)
                + 50 (x_1 - 0.5)^2

    Floating-point warnings are suppressed inside the computation.

    Raises:
        ValueError: if the number of input columns differs from the global ``d``.

    Returns:
        numpy.ndarray of shape (num_samples,); a 1-D input yields shape (1,).
    """
    global total_vector_evaluations
    global d

    pts = x.reshape(1, -1) if len(x.shape) == 1 else x

    rows, dims = pts.shape
    if dims != d:
        raise ValueError(f"Dimension mismatch in f3: Expected {d}, got {dims}")

    total_vector_evaluations += rows

    with np.errstate(divide='ignore', invalid='ignore', over='ignore', under='ignore'):
        x0, x1, x2, x3, x4, x5, x6, x7, x8, x9 = (pts[:, k] for k in range(10))

        oscillating = (x0 ** 2 + 5 * np.sin(4 * np.pi * x0)) * (x9 ** 2) * (x8 ** 3)
        exp_part = np.exp(-x1 ** 2 + x5 ** 4 - x7 * x6)
        log_part = np.log(1 + x3 + x2 * x7)
        ratio = (x0 ** 2 + x1 + x4) / (1 + x5 ** 3 + x2 ** 2 + x8 ** 4)
        well = 50 * (x1 - 0.5) ** 2  # quadratic penalty with its minimum at x_1 = 0.5

        values = oscillating + exp_part + log_part + ratio + well
    return values

def f4(x):
    """
    Objective function f4: a 10-dimensional function mixing oscillatory,
    sigmoidal, radial-bump, tanh and bilinear terms, for tensor-completion tests.

    The input is promoted to 2-D when needed. Each call adds the number of
    evaluated rows to the module-level counter ``total_vector_evaluations``.

    Parameters
    ----------
    x : numpy.ndarray
        Input points, either of shape (d,) for a single point or of shape
        (num_samples, d) for a batch.

    Returns
    -------
    numpy.ndarray
        Function values of shape (num_samples,). A single point of shape (d,)
        is evaluated as a batch of one, so the result has shape (1,) (not a
        scalar). Any nan is replaced by 0.0, +inf by 1e10 and -inf by -1e10.

    Raises
    ------
    ValueError
        If the number of input columns differs from the global ``d``.
    """
    global total_vector_evaluations
    global d

    # Treat a single point as a batch of one so that all code below is batched.
    if len(x.shape) == 1:
        x_input = x.reshape(1, -1)
    else:
        x_input = x

    num_samples, d_check = x_input.shape
    if d_check != d:
        # Name the function as f1-f3 do (previously named a non-existent function).
        raise ValueError(f"Dimension mismatch in f4: Expected {d}, got {d_check}")

    total_vector_evaluations += num_samples

    # Suppress floating-point warnings; non-finite results are cleaned up below.
    with np.errstate(divide='ignore', invalid='ignore', over='ignore', under='ignore'):
        # Term 1: oscillatory interaction between x0 and x1.
        term1 = 10 * np.sin(2 * np.pi * x_input[:, 0]) * np.cos(3 * np.pi * x_input[:, 1])

        # Term 2: logistic (S-shaped) transition in x2 + x3 + x4, centred at 1.5
        # with slope factor 5.
        term2 = 5 / (1 + np.exp(-(x_input[:, 2] + x_input[:, 3] + x_input[:, 4] - 1.5) * 5))

        # Term 3: Gaussian bump centred at (0.5, 0.5, 0.5) in (x5, x6, x7); the
        # divisor 0.05 controls its width.
        term3 = 15 * np.exp(-((x_input[:, 5] - 0.5)**2 + (x_input[:, 6] - 0.5)**2 + (x_input[:, 7] - 0.5)**2) / 0.05)

        # Term 4: bounded non-linearity of x8 * x9^2.
        term4 = 8 * np.tanh(x_input[:, 8] * x_input[:, 9]**2 - 0.5)

        # Term 5: pairwise products x_i * x_{i+5} for i = 0..4.
        term5 = (x_input[:, 0] * x_input[:, 5] +
                 x_input[:, 1] * x_input[:, 6] +
                 x_input[:, 2] * x_input[:, 7] +
                 x_input[:, 3] * x_input[:, 8] +
                 x_input[:, 4] * x_input[:, 9])

        values = term1 + term2 + term3 + term4 + term5

        # Replace non-finite results with finite sentinels so callers always get finite values.
        values = np.nan_to_num(values, nan=0.0, posinf=1e10, neginf=-1e10)
    return values

# Lookup table from the user-facing name of each objective to its implementation;
# run_tensor_completion_pipeline selects custom_function from it.
function_map = dict(f1=f1, f2=f2, f3=f3, f4=f4)

# =============================================================================
# OPTIMIZATION PIPELINE FUNCTION
# =============================================================================

def run_tensor_completion_pipeline(function_choice, optimization_method,
                                   d_param, MODAL_SIZE, sizeOmega_max,
                                   min_rank_to_try, max_rank_to_try, manual_rank, seed_tt_cross,
                                   tol_flattening, max_its_tt_cross, tol_precision,
                                   metric, tt_cross_algorithm_func, dmrg_rank_kick,
                                   sizeOmega_C, TEST_SET_SEED, increase_amount, sizeOmegaExtra,
                                   verbose=True, graphics=True):
    """
    Run TT-Cross on the chosen objective function and optionally refine the result.

    Steps:
      1. Select the objective from ``function_map``, store it in the module-level
         ``custom_function``, set the module-level ``d`` to ``d_param`` and reset
         ``total_vector_evaluations``.
      2. Build the validation set (Omega_C, A_Omega_C) with ``generate_validation_set``.
      3. Run ``optimize_tt_cross_rank_sweep`` to get a TT approximation together
         with its training samples (Omega, A_Omega).
      4. If ``optimization_method`` is None, return the TT-Cross result directly.
         Otherwise, do the following:
           - convert the approximation with ``process_tt_approximation_and_convert``
             (raising the TT rank by ``increase_amount``);
           - augment the training set with ``sizeOmegaExtra`` extra points;
           - run the selected optimizer;
           - optionally plot its loss history.

    Args:
        function_choice (str): Name of the function to approximate ('f1', 'f2', 'f3', 'f4').
        optimization_method (str or None): Optimizer to use ('adam', 'lbfgs', 'rttc', or None).
        d_param (int): Dimension of the tensor.
        MODAL_SIZE (int): Number of grid points per dimension.
        sizeOmega_max (int): Maximum allowed total vector evaluations for TT-Cross.
        min_rank_to_try (int): Minimum TT rank tried by the TT-Cross sweep.
        max_rank_to_try (int): Maximum TT rank tried by the TT-Cross sweep.
        manual_rank (int or None): Manually chosen TT rank, or None to let the sweep choose.
        seed_tt_cross (int): Seed for TT-Cross randomness.
        tol_flattening (float): Flattening tolerance for TT-Cross.
        max_its_tt_cross (int): Maximum number of sweeps for TT-Cross.
        tol_precision (float): Precision tolerance (target loss) for TT-Cross.
        metric (str): Loss metric ('rmse' or 'relative_l2'), used by TT-Cross and the optimizers.
        tt_cross_algorithm_func (callable): TT-Cross algorithm (e.g. tt_cross_regular_v2).
        dmrg_rank_kick (int): DMRG rank kick parameter.
        sizeOmega_C (int): Size of the validation set.
        TEST_SET_SEED (int): Seed for validation-set generation (also used to derive
            the augmentation seed).
        increase_amount (int): Amount by which the TT rank is increased before optimization.
        sizeOmegaExtra (int): Number of extra points added to the training set.
        verbose (bool, optional): Whether the optimizers print progress. Defaults to True.
        graphics (bool, optional): Whether to plot the loss history. Defaults to True.

    Returns:
        tuple: (X_optimized_final, loss_hist_final, val_loss_hist_final).
            With ``optimization_method=None`` this is (best_tt_approx, None, None).

    Raises:
        ValueError: if ``function_choice`` is not a key of ``function_map``, or
            ``optimization_method`` is not one of 'adam', 'lbfgs', 'rttc', None.
            Both checks happen before any function evaluation.
    """
    global total_vector_evaluations
    global d  # Keep the module-level 'd' (checked by f1-f4) in sync with d_param
    global custom_function  # The selected objective is published for other helpers

    d = d_param
    total_vector_evaluations = 0  # Reset evaluation counter for this pipeline run

    if function_choice not in function_map:
        raise ValueError(f"Invalid function_choice: {function_choice}. Choose from {list(function_map.keys())}")
    custom_function = function_map[function_choice]

    # Validate the optimizer choice before the expensive TT-Cross sweep, so a typo
    # fails immediately instead of after all function evaluations have been spent.
    valid_methods = (None, 'adam', 'lbfgs', 'rttc')
    if optimization_method not in valid_methods:
        raise ValueError("Invalid optimization_method. Choose 'adam', 'lbfgs', 'rttc', or None.")

    print(f"\n--- Running Tensor Completion Pipeline for Function: {function_choice} ---")

    # --- TT-Cross Step ---
    print(f"\n--- Running TT-Cross for Function {function_choice} ---")
    Omega_C, A_Omega_C = generate_validation_set(
        sizeOmega_C=sizeOmega_C, d=d, number_nodes=MODAL_SIZE,
        custom_function=custom_function,
        seed=TEST_SET_SEED
    )

    best_tt_approx, chosen_error_value_val, chosen_rank_found, \
    chosen_evaluations_count, Omega, A_Omega = optimize_tt_cross_rank_sweep(
        d=d, MODAL_SIZE=MODAL_SIZE, sizeOmega=sizeOmega_max,
        min_rank_to_try=min_rank_to_try, max_rank_to_try=max_rank_to_try,
        seed=seed_tt_cross, tol_flattening=tol_flattening, max_its=max_its_tt_cross,
        tol_precision=tol_precision,
        physical_point_index_fun=physical_point_index_fun,
        collecting_index_fun=collecting_index_fun,
        tt_cross_algorithm_func=tt_cross_algorithm_func,
        create_tt_random=create_tt_random, create_tt_initial=create_tt_initial,
        get_tt_shape=get_tt_shape, TensorTrain=TensorTrain,
        index_function_wrapper=index_function_wrapper,
        Omega_C=Omega_C, A_Omega_C=A_Omega_C,
        _calculate_tt_approximation_error_value=_calculate_tt_approximation_error_value,
        process_and_verify=process_and_verify,
        dmrg_rank_kick=dmrg_rank_kick, metric=metric, manual_rank=manual_rank
    )

    check_approximation_accuracy(best_tt_approx, Omega, A_Omega, Omega_C, A_Omega_C, metric='relative_l2')
    print(f"TT-Cross final validation error for Function {function_choice}: {chosen_error_value_val:.6f}")
    print(f"Total vector evaluations after TT-Cross: {total_vector_evaluations}")

    # --- Handle No Optimization Case ---
    if optimization_method is None:
        print(f"\nOptimization method set to None. Stopping after TT-Cross for Function {function_choice}.")
        # Return the TT-Cross approximation as the result, with no loss history.
        return best_tt_approx, None, None

    # --- Prepare for Optimization Step ---
    best_tt_approx_t3f = process_tt_approximation_and_convert(best_tt_approx, increase_amount)
    seed_OmegaExtra = TEST_SET_SEED + 23  # Seed for augmentation, derived from the test-set seed
    Omega, A_Omega = augment_training_set(Omega, A_Omega, sizeOmegaExtra, MODAL_SIZE, d, custom_function, seed_OmegaExtra)
    print(f"Total vector evaluations after training set augmentation: {total_vector_evaluations}")

    # Rank shown in the plot title: the manually requested rank or, if none was
    # given, the rank selected by the TT-Cross sweep (previously a hard-coded 4).
    current_rank_suffix = manual_rank if manual_rank is not None else chosen_rank_found

    # --- Run the selected optimizer ---
    if optimization_method == 'adam':
        print(f"\n--- Running with Adam optimizer for Function {function_choice} ---")
        tt_shape = get_tt_shape(best_tt_approx)
        d_optimizer = len(tt_shape)  # Dimension taken from the TT shape

        X = best_tt_approx_t3f

        # Adam hyperparameters.
        # NOTE(review): learning_rate_initial equals lr_min (both 1e-6), so a
        # plateau-based LR reduction presumably has no room to act — confirm intent.
        learning_rate_initial = 0.000001
        decay_steps = 700
        decay_rate = 0.7
        max_iters_adam = 50000
        abs_loss_threshold = 1e-4
        improvement_threshold = 1e-4
        patience_adam = 15000
        lr_reduce_on_plateau = True
        lr_patience = 1000
        lr_factor = 0.7
        lr_min_delta = 1e-5
        lr_min = 1e-6
        lr_monitor_interval = 100
        reduce_lr_train_set = True

        X_optimized_current, loss_hist_current, val_loss_hist_current = optimize_tt_with_adam(
            X, A_Omega, Omega, A_Omega_C, Omega_C, d_optimizer, MODAL_SIZE,
            max_iters=max_iters_adam,
            abs_loss_threshold=abs_loss_threshold,
            improvement_threshold=improvement_threshold,
            patience=patience_adam,
            learning_rate_initial=learning_rate_initial,
            decay_steps=decay_steps,
            decay_rate=decay_rate,
            reduce_lr_on_plateau=lr_reduce_on_plateau,
            lr_patience=lr_patience,
            lr_factor=lr_factor,
            lr_min_delta=lr_min_delta,
            lr_min=lr_min,
            lr_monitor_interval=lr_monitor_interval,
            reduce_lr_train_set=reduce_lr_train_set,
            verbose=verbose,
            loss=metric  # Optimize the same metric that TT-Cross used
        )
        method_title_suffix = "Adam Optimization"

    elif optimization_method == 'lbfgs':
        print(f"\n--- Running with L-BFGS optimizer for Function {function_choice} ---")
        tt_shape = get_tt_shape(best_tt_approx)
        d_optimizer = len(tt_shape)  # Dimension taken from the TT shape

        X = best_tt_approx_t3f

        max_iters_lbfgs = 2500
        print_every = 50

        # The fourth return value (convergence flag) is not used here.
        X_optimized_current, loss_hist_current, val_loss_hist_current, _ = optimize_tt_with_lbfgs(
            X, A_Omega, Omega, A_Omega_C, Omega_C,
            d_param=d_optimizer, number_nodes=MODAL_SIZE,
            max_iters=max_iters_lbfgs, verbose=verbose, loss=metric,
            print_every=print_every
        )
        method_title_suffix = "L-BFGS Optimization"

    else:  # 'rttc' (other values were rejected before TT-Cross)
        print(f"\n--- Running with Riemannian Tensor Completion (RTTC) optimizer for Function {function_choice} ---")
        max_iters_rttc = 1000
        tol_rttc = 1e-4
        print_interval_rttc = 100

        # NOTE(review): RTTC receives the module-level 'd' rather than len(get_tt_shape(...))
        # as the other optimizers do; these presumably coincide — confirm.
        X_optimized_current, loss_hist_current, val_loss_hist_current = riemannian_tensor_completion(
            best_tt_approx_t3f, A_Omega, Omega, A_Omega_C, Omega_C, d, MODAL_SIZE,
            max_iters=max_iters_rttc, tol_riemannian_tensor_completion=tol_rttc,
            loss=metric, verbose=verbose, print_interval=print_interval_rttc
        )
        method_title_suffix = "Riemannian Optimization"

    # Label the plot with the metric that was actually optimized; the 'relative_l2'
    # labels are exactly the ones previously hard-coded for every metric.
    metric_labels = {
        'relative_l2': ('Relative L2 Error', 'Error L2 Relativo'),
        'rmse': ('RMSE', 'RMSE'),
    }
    plot_metric_name, plot_metric_title = metric_labels.get(metric, (metric, metric))

    if graphics and loss_hist_current is not None:
        plot_loss_history(loss_hist_current, val_loss_hist_current, loss_metric_name=plot_metric_name,
                          title=f'{method_title_suffix} (Función {function_choice}, Rango {current_rank_suffix}): {plot_metric_title} vs. Iteración')
    elif graphics:
        print("No loss history to plot for this run (the optimizer returned no loss history).")

    check_approximation_accuracy(X_optimized_current, Omega, A_Omega, Omega_C, A_Omega_C, metric='relative_l2')
    check_approximation_accuracy(X_optimized_current, Omega, A_Omega, Omega_C, A_Omega_C, metric='relative_l1')
    print(f"Total vector evaluations after optimization: {total_vector_evaluations}")

    return X_optimized_current, loss_hist_current, val_loss_hist_current

# Pruebas

Antiguo (valores de tolerancia anteriores, conservados como referencia):

In [18]:
# Earlier TT-Cross stopping tolerances:
#   tol_flattening - stop when the relative (percentage) change of the loss falls below it
#   tol_precision  - target loss value
tol_flattening, tol_precision = 1e-2, 5e-3

In [21]:
# =============================================================================
# MAIN PARAMETER CHOICE AND PIPELINE EXECUTION
# =============================================================================

# --- Choose Function and Optimization Method ---
function_to_approximate = 'f2'  # One of 'f1', 'f2', 'f3', 'f4'
optimization_method_choice = None  # 'adam', 'lbfgs', 'rttc', or None (stop after TT-Cross)

# Main Parameters
d = 10  # Dimension of the tensor/function (the pipeline also sets the global 'd' from d_param)
MODAL_SIZE = 10  # Number of grid points per dimension
number_nodes = MODAL_SIZE  # Alias for MODAL_SIZE (possibly read by other notebook cells)

sizeOmega_max = 10000  # Maximum allowed total vector evaluations for TT-Cross

min_rank_to_try = 1
max_rank_to_try = 10
manual_rank = 9  # Integer to force a TT rank, or None to let the rank sweep choose
seed_tt_cross = 610022

# TT-Cross Hyperparameters
# Each tolerance is assigned once; a tighter setting that was also tried is
# tol_flattening = 1e-4, tol_precision = 1e-4.
tol_flattening = 1e-2  # Flattening tolerance (based on percentage change of loss)
tol_precision = 5e-3  # Precision tolerance (target loss)
max_its_tt_cross = 20  # Maximum number of sweeps for TT-Cross

metric = 'relative_l2'  # 'rmse' or 'relative_l2'

tt_cross_algorithm_func = tt_cross_regular_v2  # tt_cross_regular_v2 or tt_cross_regular_v2_dmrg_step
dmrg_rank_kick = 0

# Test Set Generation Parameters
sizeOmega_C = 2000
TEST_SET_SEED = 610014611

# Parameters for Post-TT-Cross processing
increase_amount = 0  # Integer that determines how many units the TT rank is increased
sizeOmegaExtra = 0  # Extra points to augment training set

# Call the pipeline function to run the selected configuration
X_optimized_final, loss_hist_final, val_loss_hist_final = run_tensor_completion_pipeline(
    function_choice=function_to_approximate,
    optimization_method=optimization_method_choice,
    d_param=d,
    MODAL_SIZE=MODAL_SIZE,
    sizeOmega_max=sizeOmega_max,
    min_rank_to_try=min_rank_to_try,
    max_rank_to_try=max_rank_to_try,
    manual_rank=manual_rank,
    seed_tt_cross=seed_tt_cross,
    tol_flattening=tol_flattening,
    max_its_tt_cross=max_its_tt_cross,
    tol_precision=tol_precision,
    metric=metric,
    tt_cross_algorithm_func=tt_cross_algorithm_func,
    dmrg_rank_kick=dmrg_rank_kick,
    sizeOmega_C=sizeOmega_C,
    TEST_SET_SEED=TEST_SET_SEED,
    increase_amount=increase_amount,
    sizeOmegaExtra=sizeOmegaExtra,
    verbose=True,
    graphics=True
)

# --- Final Results Summary ---
print("\n--- Pipeline Execution Complete ---")
if X_optimized_final is not None:
    # The result may be the optimized t3f tensor or, without optimization, the
    # TT-Cross approximation; report the first core's shape when 'cores' is available.
    final_cores = getattr(X_optimized_final, 'cores', None)
    first_core_shape = final_cores[0].shape if final_cores else 'N/A'
    print(f"Final optimized TT approximation (first core shape): {first_core_shape}")
if loss_hist_final is None:
    print("No optimization method was chosen, so no final training/validation loss history from optimization.")
else:
    # Guard against empty histories so indexing [-1] cannot raise IndexError.
    if loss_hist_final:
        print(f"Final training loss: {loss_hist_final[-1]:.6f}")
    else:
        print("No training loss values were recorded.")
    if val_loss_hist_final:
        print(f"Final validation loss: {val_loss_hist_final[-1]:.6f}")
    else:
        print("No validation loss values were recorded.")


--- Running Tensor Completion Pipeline for Function: f2 ---

--- Running TT-Cross for Function f2 ---
Iniciando barrido de rangos para TT-Cross (Métrica: Error Relativo L2)
Validación externa: 2000 puntos.
------------------------------

Probando con rango TT objetivo: 1
  Rango TT inicial: (1, 1, 1, 1, 1, 1, 1, 1, 1)
  Forma TTML inicial: (10, 10, 10, 10, 10, 10, 10, 10, 10, 10)
  Rango TT final de la aproximación: (1, 1, 1, 1, 1, 1, 1, 1, 1)
  Evaluaciones totales de vectores (llamadas a physical_point_index_fun): 2000
  Pares (índice, valor) recolectados individualmente: 2000
  Iteraciones completadas: 20
  Razón de parada del algoritmo TT-Cross: max_iterations
  Tiempo de ejecución para este rango: 0.40 segundos
  [92mError Relativo L2 (en Conjunto de Validación Omega_C, recalculado): 2.7283022481e-01[0m
  ¡Nueva mejor aproximación automática encontrada!
    Rango: 1, Error Relativo L2 (Validación): 2.7283022481e-01, Evaluaciones: 2000

Probando con rango TT objetivo: 2
  Rango 