In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import math
import time
from scipy.stats import norm
from collections import deque, defaultdict


# --- Environment and Utility Functions ---
def encode_15puzzle_state(state):
    """Build the one-hot position encoding of a 15-puzzle board.

    Every tile value 0..15 owns eight consecutive slots of the output:
    the first four one-hot encode the tile's row, the last four its column.

    Args:
        state: Sequence of the 16 tile values, where index ``i`` is board
            cell ``(i // 4, i % 4)``. Must provide ``.index``.

    Returns:
        A NumPy float array of length 128 containing two ones per tile.
    """
    slots_per_tile = 8
    features = np.zeros(16 * 2 * 4)
    for tile in range(16):
        position = state.index(tile)
        base = tile * slots_per_tile
        features[base + position // 4] = 1  # row one-hot
        features[base + 4 + position % 4] = 1  # column one-hot
    return features


def get_valid_moves(state):
    """Enumerate the boards reachable by sliding the blank one cell.

    Args:
        state: List of the 16 tile values with ``0`` marking the blank.

    Returns:
        A list of new boards (copies; ``state`` is left untouched) in the
        fixed order up, down, left, right, omitting moves that would leave
        the 4x4 grid.
    """
    blank = state.index(0)
    blank_row, blank_col = divmod(blank, 4)
    successors = []
    for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        target_row = blank_row + d_row
        target_col = blank_col + d_col
        if not (0 <= target_row < 4 and 0 <= target_col < 4):
            continue
        target = target_row * 4 + target_col
        board = state[:]
        # Swap the blank with the tile it slides onto.
        moved_tile = board[target]
        board[target] = board[blank]
        board[blank] = moved_tile
        successors.append(board)
    return successors


def manhattan_distance(state, goal_state):
    """Sum the grid distances of tiles 1..15 from their goal cells.

    The blank (tile 0) is not counted.

    Args:
        state: Current board as a sequence of 16 tile values.
        goal_state: Target board in the same layout.

    Returns:
        The total of ``|row difference| + |column difference|`` over tiles
        1 through 15.
    """

    def tile_offset(tile):
        # Distance between one tile's current cell and its goal cell.
        row_now, col_now = divmod(state.index(tile), 4)
        row_goal, col_goal = divmod(goal_state.index(tile), 4)
        return abs(row_now - row_goal) + abs(col_now - col_goal)

    return sum(tile_offset(tile) for tile in range(1, 16))


# --- Neural Network Definitions ---
class BayesianLinear(nn.Module):
    """Linear layer with an independent Gaussian posterior per weight and bias.

    Each parameter is described by a mean and a log-variance. The forward
    pass samples the layer *outputs* (local reparameterization) rather than
    the weights, so every row of a batch receives its own noise.
    """

    def __init__(self, in_features, out_features, prior_mu=0.0, prior_sigma=1.0):
        """Create the variational parameters and initialise them from the prior.

        Args:
            in_features: Size of each input row.
            out_features: Size of each output row.
            prior_mu: Mean of the Gaussian prior on every weight and bias.
            prior_sigma: Standard deviation of that prior.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        weight_shape = (out_features, in_features)
        self.weight_mu = nn.Parameter(torch.empty(weight_shape))
        self.weight_logvar = nn.Parameter(torch.empty(weight_shape))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_logvar = nn.Parameter(torch.empty(out_features))

        # Means start near the prior mean (a tenth of the prior spread);
        # log-variances start exactly at the prior's log-variance.
        mean_spread = prior_sigma / 10
        prior_logvar = math.log(prior_sigma**2)
        nn.init.normal_(self.weight_mu, mean=prior_mu, std=mean_spread)
        nn.init.constant_(self.weight_logvar, prior_logvar)
        nn.init.normal_(self.bias_mu, mean=prior_mu, std=mean_spread)
        nn.init.constant_(self.bias_logvar, prior_logvar)

        # Kept so kl_divergence() can compare the posterior against the prior.
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma

    def forward(self, x):
        """Return one stochastic sample of the layer output for ``x``."""
        out_mean = F.linear(x, self.weight_mu, self.bias_mu)
        # Output variance: squared inputs weighted by the parameter variances.
        out_var = F.linear(
            x.pow(2), self.weight_logvar.exp(), self.bias_logvar.exp()
        )
        noise = torch.randn_like(out_mean)
        # The 1e-8 keeps the square root away from zero.
        return out_mean + (out_var + 1e-8).sqrt() * noise

    def _kl_to_prior(self, mu, logvar):
        """Element-wise KL(N(mu, exp(logvar)) || N(prior_mu, prior_sigma^2))."""
        prior_var = self.prior_sigma**2
        return 0.5 * ((mu - self.prior_mu).pow(2) + torch.exp(logvar)) / (
            prior_var
        ) - 0.5 * (1 + logvar - math.log(prior_var))

    def kl_divergence(self):
        """Return the summed KL divergence of all weights and biases from the prior."""
        weight_kl = self._kl_to_prior(self.weight_mu, self.weight_logvar)
        bias_kl = self._kl_to_prior(self.bias_mu, self.bias_logvar)
        return weight_kl.sum() + bias_kl.sum()


class WUNN(nn.Module):
    """Weight Uncertainty Neural Network used to estimate epistemic uncertainty.

    Two ``BayesianLinear`` layers with a ReLU in between. Every forward pass
    is stochastic, so repeated passes on the same input give a spread of
    outputs whose variance is taken as the epistemic uncertainty.
    """

    def __init__(self, input_dim, hidden_dim=20, S=5, prior_mu=0.0, prior_sigma=1.0):
        """Build the two Bayesian layers.

        Args:
            input_dim: Length of the encoded state vector.
            hidden_dim: Width of the hidden layer.
            S: Number of stochastic forward passes callers average per
                training example (stored only; not used inside this class).
            prior_mu: Prior mean passed to both layers.
            prior_sigma: Prior standard deviation passed to both layers.
        """
        super().__init__()
        self.S = S  # Number of samples for forward pass
        self.fc1 = BayesianLinear(input_dim, hidden_dim, prior_mu, prior_sigma)
        self.fc2 = BayesianLinear(hidden_dim, 1, prior_mu, prior_sigma)

    def forward_single(self, x):
        """Run one stochastic forward pass; each row of ``x`` is sampled independently."""
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def predict_sigma_e(self, x, K=100):
        """Estimate the epistemic variance of the prediction for one state.

        The input is replicated ``K`` times and pushed through the network in
        a single batched pass. The layers draw fresh noise for every output
        row, so the ``K`` rows are independent samples, as ``K`` separate
        passes would be, at a fraction of the cost.

        Args:
            x: Encoded state (array-like of length ``input_dim``).
            K: Number of stochastic samples to draw.

        Returns:
            The population variance (``numpy.var``) of the ``K`` sampled
            outputs as a float64 scalar.

        Note:
            Leaves the module in eval mode.
        """
        self.eval()
        x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            # expand() creates a view, so no K-fold copy of the input is made.
            batch = x_tensor.expand(K, *x_tensor.shape[1:])
            samples = self.forward_single(batch).reshape(-1)
            # float64 keeps the variance computation and return type the
            # same as when the samples were collected as Python floats.
            outputs = samples.numpy().astype(np.float64)
        return np.var(outputs)


class FFNN(nn.Module):
    """Feedforward network with a mean head and a log-variance head.

    The second head's output is exponentiated by callers to obtain the
    aleatoric variance of the cost-to-goal prediction.
    """

    def __init__(self, input_dim, hidden_dim=20, dropout_rate=0.1):
        """Create one shared hidden layer, two scalar heads and dropout.

        Args:
            input_dim: Length of the encoded state vector.
            hidden_dim: Width of the shared hidden layer.
            dropout_rate: Dropout probability applied after the hidden ReLU.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2_mean = nn.Linear(hidden_dim, 1)
        self.fc2_var = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(dropout_rate)

        # He init for the ReLU layer, Xavier for the linear heads, zero biases.
        nn.init.kaiming_normal_(self.fc1.weight, mode="fan_in", nonlinearity="relu")
        for head in (self.fc2_mean, self.fc2_var):
            nn.init.xavier_normal_(head.weight)
        for layer in (self.fc1, self.fc2_mean, self.fc2_var):
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return ``(mean, log_variance)`` for a batch of encoded states."""
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2_mean(hidden), self.fc2_var(hidden)

    def predict(self, x):
        """Return ``(mean, variance)`` as Python floats for a single input.

        Switches the module to eval mode (disabling dropout) and leaves it
        there.
        """
        self.eval()
        with torch.no_grad():
            mean, logvar = self.forward(x)
            variance = torch.exp(logvar)
            return mean.item(), variance.item()


# --- Main Algorithm Implementation ---
class LearnHeuristicPrac:
    def __init__(self, input_dim, goal_state, params):
        """Create both networks and copy the hyper-parameters onto the instance.

        Args:
            input_dim: Length of the encoded state vector fed to both networks.
            goal_state: Goal board; task generation starts its walk from it.
            params: Dict of hyper-parameters. Keys read here: ``hidden_dim``,
                ``mu0``, ``sigma2_0``, ``dropout_rate``, ``alpha0``, ``beta0``,
                ``epsilon``, ``delta``, ``kappa``, ``q``, ``K``, ``MaxSteps``,
                ``NumIter``, ``NumTasksPerIter``, ``NumTasksPerIterThresh``,
                ``TrainIter``, ``MaxTrainIter``, ``MiniBatchSize``, ``tmax``
                and ``MemoryBufferMaxRecords``.

        Raises:
            KeyError: If any of the keys above is missing from ``params``.
        """
        # Initialize models with prior parameters
        self.nnWUNN = WUNN(
            input_dim,
            params["hidden_dim"],
            prior_mu=params["mu0"],
            prior_sigma=math.sqrt(params["sigma2_0"]),
        )
        self.nnFFNN = FFNN(input_dim, params["hidden_dim"], params["dropout_rate"])

        # Algorithm parameters
        self.alpha = params["alpha0"]
        self.beta = params["beta0"]
        self.epsilon = params["epsilon"]
        self.delta = params["delta"]
        self.kappa = params["kappa"]
        # Per-iteration decay factor chosen so that beta0 * gamma**NumIter
        # equals the target final beta.
        final_beta = 0.00001
        self.gamma = (final_beta / params["beta0"]) ** (1 / params["NumIter"])
        self.q = params["q"]
        self.K = params["K"]
        self.max_steps = params["MaxSteps"]

        # Training parameters
        self.NumIter = params["NumIter"]
        self.NumTasksPerIter = params["NumTasksPerIter"]
        self.NumTasksPerIterThresh = params["NumTasksPerIterThresh"]
        self.TrainIter = params["TrainIter"]
        self.MaxTrainIter = params["MaxTrainIter"]
        self.MiniBatchSize = params["MiniBatchSize"]
        self.tmax = params["tmax"]

        # Memory buffer of (encoded_state, cost) records, bounded in size.
        self.memoryBuffer = deque(maxlen=params["MemoryBufferMaxRecords"])

        # Quantile threshold read by uncertainty_aware_heuristic(). It is
        # defined here so the heuristic can be called before run() has
        # executed; -inf makes the heuristic use epsilon as the variance.
        self.yq = -np.inf

        # Metrics tracking
        self.planner_costs = []
        self.optimal_costs = []
        self.suboptimalities = []
        self.optimality_counts = 0
        self.goal_state = goal_state

    def h(self, alpha, mu, sigma):
        """Quantile function of normal distribution"""
        return mu + sigma * norm.ppf(alpha)

    def generate_task(self):
        """Walk away from the goal until an uncertain start state is found (Algorithm 3).

        At each of at most ``max_steps`` steps the successors of the current
        board (except the board just left) are scored with the WUNN's
        epistemic variance, and one is sampled with probability given by a
        softmax over those variances. The walk stops as soon as the sampled
        successor's variance reaches ``epsilon``.

        Returns:
            A dict with keys ``"s"`` (start board), ``"sg"`` (goal board) and
            ``"sigma2_e"`` (the start board's epistemic variance), or ``None``
            if no sufficiently uncertain state was reached.
        """
        current = self.goal_state[:]
        previous = None

        for _ in range(self.max_steps):
            # Score every successor; tuples make the boards usable as keys.
            candidates = {}
            for successor in get_valid_moves(current):
                if previous is not None and successor == previous:
                    continue  # never step straight back
                features = encode_15puzzle_state(successor)
                candidates[tuple(successor)] = self.nnWUNN.predict_sigma_e(
                    features, self.K
                )

            if not candidates:
                break

            # Softmax sampling over the epistemic variances.
            boards = tuple(candidates.keys())
            variances = tuple(candidates.values())
            probs = F.softmax(torch.tensor(variances), dim=0).numpy()
            pick = np.random.choice(len(boards), p=probs)
            chosen_board = list(boards[pick])
            chosen_variance = variances[pick]

            if chosen_variance >= self.epsilon:
                return {
                    "s": chosen_board,
                    "sg": self.goal_state,
                    "sigma2_e": chosen_variance,
                }

            previous, current = current, chosen_board

        return None

    def ida_star(self, start, goal, heuristic, tmax, start_time):
        """Search from ``start`` to ``goal`` with IDA* under a wall-clock limit.

        Args:
            start: Initial board.
            goal: Target board.
            heuristic: Callable mapping a board to an estimated cost-to-goal.
            tmax: Maximum number of seconds allowed since ``start_time``.
            start_time: ``time.time()`` value the limit is measured from.

        Returns:
            The list of boards from ``start`` to ``goal`` (inclusive), or
            ``None`` if the search space is exhausted.

        Raises:
            TimeoutError: If the time limit is exceeded during the search.
        """
        route = [start]

        def dfs(cost_so_far, limit):
            # Returns True when the goal is reached, otherwise the smallest
            # f-value that exceeded ``limit`` in this subtree.
            if time.time() - start_time > tmax:
                raise TimeoutError()

            current = route[-1]
            estimate = cost_so_far + heuristic(current)
            if estimate > limit:
                return estimate
            if current == goal:
                return True

            next_limit = float("inf")
            for successor in get_valid_moves(current):
                if successor in route:
                    continue  # avoid cycles along the current path
                route.append(successor)
                outcome = dfs(cost_so_far + 1, limit)
                if outcome is True:
                    # Leave the successor on the route: it is part of the plan.
                    return True
                if outcome < next_limit:
                    next_limit = outcome
                route.pop()
            return next_limit

        limit = heuristic(start)
        while True:
            outcome = dfs(0, limit)
            if outcome is True:
                return route
            if outcome == float("inf"):
                return None
            limit = outcome

    def uncertainty_aware_heuristic(self, state):
        """Compute h(s) = max(h(α, ŷ(s), σ_t(s)), 0) from the FFNN prediction.

        The variance used is the FFNN's predicted variance when the predicted
        mean lies below ``self.yq``, and ``self.epsilon`` otherwise.

        Args:
            state: Board to evaluate.

        Returns:
            The non-negative heuristic value for ``state``.

        Note:
            Reads ``self.yq``, which must have been set before the call, and
            switches the FFNN to eval mode.
        """
        features = torch.tensor(
            encode_15puzzle_state(state), dtype=torch.float32
        ).unsqueeze(0)

        self.nnFFNN.eval()
        with torch.no_grad():
            mean, logvar = self.nnFFNN(features)
            y_hat = mean.item()
            sigma2_a = torch.exp(logvar).item()

        if y_hat < self.yq:
            variance = sigma2_a
        else:
            variance = self.epsilon

        quantile = self.h(self.alpha, y_hat, math.sqrt(variance))
        return max(quantile, 0)

    def compute_metrics(self):
        """Compute suboptimality and optimality metrics"""
        if not self.planner_costs:
            return 0.0, 0.0

        # Calculate suboptimality (u_i)
        suboptimalities = [
            (y / y_star) - 1
            for y, y_star in zip(self.planner_costs, self.optimal_costs)
            if y_star > 0  # Avoid division by zero
        ]
        avg_suboptimality = (
            sum(suboptimalities) / len(suboptimalities) if suboptimalities else 0.0
        )

        # Calculate optimality rate (% tasks solved optimally)
        optimality_rate = (
            (self.optimality_counts / len(self.planner_costs)) * 100
            if self.planner_costs
            else 0.0
        )

        return avg_suboptimality, optimality_rate

    def train_ffnn(self):
        """Trains FFNN on entire memory buffer"""
        if len(self.memoryBuffer) < self.MiniBatchSize:
            return

        optimizer = optim.Adam(self.nnFFNN.parameters())
        criterion = nn.GaussianNLLLoss()

        # Convert memory buffer to tensors
        x_data = torch.stack(
            [torch.tensor(x, dtype=torch.float32) for x, _ in self.memoryBuffer]
        )
        y_data = torch.tensor(
            [y for _, y in self.memoryBuffer], dtype=torch.float32
        ).unsqueeze(1)

        self.nnFFNN.train()
        for _ in range(self.TrainIter):
            # Shuffle and batch the data
            permutation = torch.randperm(len(x_data))
            for i in range(0, len(x_data), self.MiniBatchSize):
                indices = permutation[i : i + self.MiniBatchSize]
                x_batch, y_batch = x_data[indices], y_data[indices]

                optimizer.zero_grad()
                mean, logvar = self.nnFFNN(x_batch)
                loss = criterion(mean, y_batch, torch.exp(logvar))
                loss.backward()
                optimizer.step()

    def train_wunn(self):
        """Trains WUNN with early stopping condition (Algorithm 4 line 29-32)"""
        if len(self.memoryBuffer) < self.MiniBatchSize:
            return False

        self.nnWUNN.train()
        optimizer = optim.Adam(self.nnWUNN.parameters())
        early_stop = False

        for iter in range(self.MaxTrainIter):
            # Compute loss on entire buffer periodically
            if iter % 10 == 0:
                all_low_uncertainty = True
                for x, _ in list(self.memoryBuffer)[
                    :100
                ]:  # Check subset for efficiency
                    sigma2_e = self.nnWUNN.predict_sigma_e(
                        x, 10
                    )  # Smaller K for faster checking
                    if sigma2_e >= self.kappa * self.epsilon:
                        all_low_uncertainty = False
                        break

                if all_low_uncertainty:
                    early_stop = True
                    break

            # Mini-batch training
            batch = random.sample(
                self.memoryBuffer, min(self.MiniBatchSize, len(self.memoryBuffer))
            )
            total_loss = 0

            for x, y in batch:
                x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(0)
                y_tensor = torch.tensor([y], dtype=torch.float32)

                # Forward pass with multiple samples
                preds = torch.stack(
                    [self.nnWUNN.forward_single(x_tensor) for _ in range(self.nnWUNN.S)]
                )

                # Compute loss
                log_likelihood = -F.mse_loss(preds.mean(), y_tensor)
                kl_div = (
                    self.nnWUNN.fc1.kl_divergence() + self.nnWUNN.fc2.kl_divergence()
                )
                loss = self.beta * kl_div - log_likelihood

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

        return early_stop

    def run(self):
        """Main learning loop with strict β decay schedule.

        For each of ``NumIter`` iterations:

        1. Refresh ``yq`` as the ``q``-quantile of the costs in the buffer.
        2. Generate up to ``NumTasksPerIter`` tasks and solve each with IDA*
           under the uncertainty-aware heuristic and the ``tmax`` time limit.
        3. Store every non-goal state of each found plan in the memory buffer,
           labelled with its cost-to-goal along that plan.
        4. Lower ``alpha`` by ``delta`` (not below 0.5) if fewer than
           ``NumTasksPerIterThresh`` tasks were solved.
        5. Train both networks, multiply ``beta`` by ``gamma`` and print the
           running metrics.
        """
        self.yq = -np.inf

        for n in range(1, self.NumIter + 1):
            print(f"\n=== Iteration {n}/{self.NumIter} ===")
            print(f"Current β: {self.beta:.6f}")  # Track β precisely

            # Update yq from memory buffer
            if self.memoryBuffer:
                costs = [y for _, y in self.memoryBuffer]
                self.yq = np.quantile(costs, self.q)
            print(f"Current yq (q={self.q}): {self.yq:.2f}, α: {self.alpha:.3f}")

            # Generate and solve tasks
            numSolved = 0
            for i in range(self.NumTasksPerIter):
                T = self.generate_task()
                if not T:
                    print("No task generated (low uncertainty)")
                    continue

                # Only the search itself can time out, so only it is guarded.
                try:
                    start_time = time.time()
                    plan = self.ida_star(
                        T["s"], T["sg"],
                        self.uncertainty_aware_heuristic,
                        self.tmax, start_time
                    )
                except TimeoutError:
                    print(f"⏳ Task {i+1} timed out")
                    continue

                if not plan:
                    continue

                numSolved += 1
                plan_cost = len(plan) - 1
                # NOTE: Manhattan distance is a lower bound on the true
                # optimal cost, so the suboptimality reported from it is an
                # upper bound and the optimality rate a lower bound.
                optimal_cost = manhattan_distance(T["s"], T["sg"])
                self.planner_costs.append(plan_cost)
                self.optimal_costs.append(optimal_cost)
                if plan_cost == optimal_cost:
                    self.optimality_counts += 1
                print(f"✓ Task {i+1}: cost={plan_cost}, optimal={optimal_cost}")

                # Walk the plan backwards from the state next to the goal.
                # Each state is labelled with its remaining plan length (every
                # move costs 1), so the networks learn the planner's
                # cost-to-goal rather than re-learning Manhattan distance.
                for cost_to_goal, state in enumerate(reversed(plan[:-1]), start=1):
                    x = encode_15puzzle_state(state)
                    self.memoryBuffer.appendleft((x, cost_to_goal))

            # Update α (conditionally)
            if numSolved < self.NumTasksPerIterThresh:
                self.alpha = max(self.alpha - self.delta, 0.5)
                print(f"Reduced α to {self.alpha:.3f} (solved {numSolved} tasks)")

            # Train models
            print("Training models...")
            self.train_ffnn()
            _ = self.train_wunn()  # early_stop ignored

            # Strict β decay (unconditional)
            self.beta *= self.gamma
            print(f"Decayed β to {self.beta:.6f} (γ={self.gamma:.6f})")

            # Log metrics
            avg_subopt, opt_rate = self.compute_metrics()
            print(f"Iteration {n} results:")
            print(f"  Solved: {numSolved}/{self.NumTasksPerIter}")
            print(f"  Suboptimality: {avg_subopt:.3f}")
            print(f"  Optimality Rate: {opt_rate:.1f}%")


# --- Example Usage ---
if __name__ == "__main__":
    # The goal board is tiles 0..15 in order; its encoding fixes the input size.
    goal_state = list(range(16))
    input_dim = len(encode_15puzzle_state(goal_state))

    # Hyper-parameters for the learner (gamma is derived in __init__).
    params = dict(
        # Network shape and regularisation
        hidden_dim=20,
        dropout_rate=0.1,
        # Quantile level, KL weight and uncertainty thresholds
        alpha0=0.99,
        beta0=0.05,  # decays to 0.00001 over NumIter iterations
        epsilon=1.0,
        delta=0.05,
        kappa=0.64,
        q=0.95,
        # Task generation
        K=100,
        MaxSteps=15,
        # WUNN prior
        mu0=0.0,
        sigma2_0=10.0,
        # Outer loop and training budgets
        NumIter=50,  # number of steps over which beta decays
        NumTasksPerIter=10,
        NumTasksPerIterThresh=6,
        TrainIter=200,
        MaxTrainIter=1000,
        MiniBatchSize=32,
        tmax=30,
        MemoryBufferMaxRecords=25000,
    )

    learner = LearnHeuristicPrac(input_dim, goal_state, params)
    learner.run()

    # Report the metrics accumulated over the whole run.
    avg_subopt, opt_rate = learner.compute_metrics()
    print("\n=== Final Metrics ===")
    print(f"Average Suboptimality: {avg_subopt:.3f}")
    print(f"Optimality Rate: {opt_rate:.1f}%")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import math
import time
from scipy.stats import norm
from collections import deque, defaultdict


# --- Environment and Utility Functions ---
def encode_15puzzle_state(state):
    """Encode a 15-puzzle board as per-tile row and column one-hot vectors.

    The output has eight slots per tile value (0..15): slots 0-3 mark the
    tile's row and slots 4-7 mark its column.

    Args:
        state: Sequence of the 16 tile values in row-major board order.
            Must provide ``.index``.

    Returns:
        A NumPy float array of length 128 with two entries set to 1 per tile.
    """
    encoding = np.zeros(16 * 2 * 4)
    for tile in range(16):
        cell = state.index(tile)
        offset = tile * 8
        encoding[offset + cell // 4] = 1  # row
        encoding[offset + 4 + cell % 4] = 1  # column
    return encoding


def get_valid_moves(state):
    """Return the boards obtained by moving the blank up, down, left or right.

    Args:
        state: List of the 16 tile values; ``0`` is the blank.

    Returns:
        A list of fresh board lists, one per legal move, in the order up,
        down, left, right. ``state`` itself is not modified.
    """
    hole = state.index(0)
    hole_row, hole_col = divmod(hole, 4)
    boards = []
    for step_row, step_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        row = hole_row + step_row
        col = hole_col + step_col
        if row < 0 or row >= 4 or col < 0 or col >= 4:
            continue  # move would leave the grid
        destination = row * 4 + col
        moved = state[:]
        # Exchange the blank and the tile at the destination cell.
        tile = moved[destination]
        moved[destination] = moved[hole]
        moved[hole] = tile
        boards.append(moved)
    return boards


def manhattan_distance(state, goal_state):
    """Return the Manhattan-distance heuristic between two boards.

    Adds up, for each of tiles 1..15 (the blank is ignored), the number of
    rows plus the number of columns separating its cell in ``state`` from
    its cell in ``goal_state``.

    Args:
        state: Current board as a sequence of 16 tile values.
        goal_state: Target board in the same layout.
    """
    distance = 0
    for tile in range(1, 16):
        here_row, here_col = divmod(state.index(tile), 4)
        there_row, there_col = divmod(goal_state.index(tile), 4)
        distance += abs(here_row - there_row) + abs(here_col - there_col)
    return distance


# --- Neural Network Definitions ---
class BayesianLinear(nn.Module):
    """Bayesian linear layer that samples its outputs (local reparameterization).

    Weights and biases each have a learnable mean and log-variance. The
    forward pass computes the mean and variance of the pre-activation and
    draws one Gaussian sample per output element.
    """

    def __init__(self, in_features, out_features, prior_mu=0.0, prior_sigma=1.0):
        """Allocate the variational parameters and initialise them from the prior.

        Args:
            in_features: Number of input features.
            out_features: Number of output features.
            prior_mu: Prior mean shared by all weights and biases.
            prior_sigma: Prior standard deviation shared by all of them.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_logvar = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_logvar = nn.Parameter(torch.empty(out_features))

        # Means: small Gaussian around the prior mean. Log-variances: the
        # prior's log-variance.
        log_prior_var = math.log(prior_sigma**2)
        for mu_param, logvar_param in (
            (self.weight_mu, self.weight_logvar),
            (self.bias_mu, self.bias_logvar),
        ):
            nn.init.normal_(mu_param, mean=prior_mu, std=prior_sigma / 10)
            nn.init.constant_(logvar_param, log_prior_var)

        # Prior parameters, needed again by kl_divergence().
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma

    def forward(self, x):
        """Draw one sample of the layer output for every row of ``x``."""
        mean = F.linear(x, self.weight_mu, self.bias_mu)
        variance = F.linear(
            x.pow(2), torch.exp(self.weight_logvar), torch.exp(self.bias_logvar)
        )
        std = torch.sqrt(variance + 1e-8)  # small constant for stability
        return mean + std * torch.randn_like(mean)

    def kl_divergence(self):
        """Return KL(posterior || prior) summed over every weight and bias."""
        prior_var = self.prior_sigma**2
        log_prior_var = math.log(prior_var)

        total = None
        for mu, logvar in (
            (self.weight_mu, self.weight_logvar),
            (self.bias_mu, self.bias_logvar),
        ):
            # Closed-form KL between two univariate Gaussians, element-wise.
            kl = 0.5 * ((mu - self.prior_mu).pow(2) + torch.exp(logvar)) / (
                prior_var
            ) - 0.5 * (1 + logvar - log_prior_var)
            total = kl.sum() if total is None else total + kl.sum()
        return total


class WUNN(nn.Module):
    """Weight Uncertainty Neural Network for epistemic uncertainty.

    A two-layer network of ``BayesianLinear`` layers with a ReLU in between;
    each forward pass is a fresh stochastic sample.
    """

    def __init__(self, input_dim, hidden_dim=20, S=5, prior_mu=0.0, prior_sigma=1.0):
        """Build the two Bayesian layers.

        Args:
            input_dim: Length of the encoded state vector.
            hidden_dim: Width of the hidden layer.
            S: Number of stochastic forward passes callers average per
                training example (stored only; not used inside this class).
            prior_mu: Prior mean passed to both layers.
            prior_sigma: Prior standard deviation passed to both layers.
        """
        super().__init__()
        self.S = S  # Number of samples for forward pass
        self.fc1 = BayesianLinear(input_dim, hidden_dim, prior_mu, prior_sigma)
        self.fc2 = BayesianLinear(hidden_dim, 1, prior_mu, prior_sigma)

    def forward_single(self, x):
        """Run one stochastic forward pass; rows of ``x`` are sampled independently."""
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def predict_sigma_e(self, x, K=100):
        """Estimate the epistemic variance of the prediction for one state.

        Instead of calling the network ``K`` times, the input is viewed as a
        batch of ``K`` identical rows and evaluated once. The layers sample
        noise per output row, so this yields ``K`` independent samples.

        Args:
            x: Encoded state (array-like of length ``input_dim``).
            K: Number of stochastic samples to draw.

        Returns:
            The population variance (``numpy.var``) of the ``K`` sampled
            outputs as a float64 scalar.

        Note:
            Leaves the module in eval mode.
        """
        self.eval()
        x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            # A broadcast view: the input is not physically copied K times.
            repeated = x_tensor.expand(K, *x_tensor.shape[1:])
            draws = self.forward_single(repeated).reshape(-1)
            # Compute the variance in float64, as with a list of Python floats.
            outputs = draws.numpy().astype(np.float64)
        return np.var(outputs)


class FFNN(nn.Module):
    """Feedforward network predicting a mean and a log-variance per state.

    Callers exponentiate the second output to obtain the aleatoric variance.
    """

    def __init__(self, input_dim, hidden_dim=20, dropout_rate=0.025):
        """Create the shared hidden layer, the two scalar heads and dropout.

        Args:
            input_dim: Length of the encoded state vector.
            hidden_dim: Width of the shared hidden layer.
            dropout_rate: Dropout probability applied after the hidden ReLU.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2_mean = nn.Linear(hidden_dim, 1)
        self.fc2_var = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(dropout_rate)

        # All three layers: He-normal weights, zero biases.
        layers = (self.fc1, self.fc2_mean, self.fc2_var)
        for layer in layers:
            nn.init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="relu")
        for layer in layers:
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return ``(mean, log_variance)`` for a batch of encoded states."""
        features = F.relu(self.fc1(x))
        features = self.dropout(features)
        mean = self.fc2_mean(features)
        logvar = self.fc2_var(features)
        return mean, logvar

    def predict(self, x):
        """Return ``(mean, variance)`` as Python floats for a single input.

        Puts the module in eval mode (dropout off) and leaves it there.
        """
        self.eval()
        with torch.no_grad():
            mean, logvar = self.forward(x)
            return mean.item(), logvar.exp().item()


# --- Main Algorithm Implementation ---
class LearnHeuristicPrac:
    def __init__(self, input_dim, goal_state, params):
        """Create both networks and copy the hyper-parameters onto the instance.

        Args:
            input_dim: Length of the encoded state vector fed to both networks.
            goal_state: Goal board; task generation starts its walk from it.
            params: Dict of hyper-parameters. Keys read here: ``hidden_dim``,
                ``mu0``, ``sigma2_0``, ``dropout_rate``, ``alpha0``, ``beta0``,
                ``epsilon``, ``delta``, ``kappa``, ``q``, ``K``, ``MaxSteps``,
                ``NumIter``, ``NumTasksPerIter``, ``NumTasksPerIterThresh``,
                ``TrainIter``, ``MaxTrainIter``, ``MiniBatchSize``, ``tmax``
                and ``MemoryBufferMaxRecords``.

        Raises:
            KeyError: If any of the keys above is missing from ``params``.
        """
        # Initialize models with prior parameters
        self.nnWUNN = WUNN(
            input_dim,
            params["hidden_dim"],
            prior_mu=params["mu0"],
            prior_sigma=math.sqrt(params["sigma2_0"]),
        )
        self.nnFFNN = FFNN(input_dim, params["hidden_dim"], params["dropout_rate"])

        # Algorithm parameters
        self.alpha = params["alpha0"]
        self.beta = params["beta0"]
        self.epsilon = params["epsilon"]
        self.delta = params["delta"]
        self.kappa = params["kappa"]
        # Decay factor such that beta0 * gamma**NumIter equals the target
        # final beta.
        final_beta = 0.00001
        self.gamma = (final_beta / params["beta0"]) ** (1 / params["NumIter"])
        self.q = params["q"]
        self.K = params["K"]
        self.max_steps = params["MaxSteps"]

        # Training parameters
        self.NumIter = params["NumIter"]
        self.NumTasksPerIter = params["NumTasksPerIter"]
        self.NumTasksPerIterThresh = params["NumTasksPerIterThresh"]
        self.TrainIter = params["TrainIter"]
        self.MaxTrainIter = params["MaxTrainIter"]
        self.MiniBatchSize = params["MiniBatchSize"]
        self.tmax = params["tmax"]

        # Memory buffer of (encoded_state, cost) records, bounded in size.
        self.memoryBuffer = deque(maxlen=params["MemoryBufferMaxRecords"])

        # Quantile threshold read by uncertainty_aware_heuristic(); set here
        # so the heuristic is usable on a freshly constructed instance.
        # With -inf the heuristic uses epsilon as the variance.
        self.yq = -np.inf

        # Metrics tracking
        self.planner_costs = []
        self.optimal_costs = []
        self.suboptimalities = []
        self.optimality_counts = 0
        self.goal_state = goal_state

    def h(self, alpha, mu, sigma):
        """Quantile function of normal distribution"""
        return mu + sigma * norm.ppf(alpha)

    def generate_task(self):
        """Generate a start state with high epistemic uncertainty (Algorithm 3).

        Starting at the goal, repeatedly scores the successors of the current
        board (skipping the board visited one step earlier) with the WUNN's
        epistemic variance and samples one via a softmax over those scores.
        The first sampled board whose variance is at least ``epsilon``
        becomes the task's start state.

        Returns:
            ``{"s": start, "sg": goal, "sigma2_e": variance}``, or ``None``
            when ``max_steps`` steps pass without finding such a board.
        """
        here = self.goal_state[:]
        came_from = None

        for _ in range(self.max_steps):
            scored = {}
            for nxt in get_valid_moves(here):
                stepping_back = came_from is not None and nxt == came_from
                if stepping_back:
                    continue
                encoded = encode_15puzzle_state(nxt)
                variance = self.nnWUNN.predict_sigma_e(encoded, self.K)
                scored[tuple(nxt)] = variance  # tuple: lists are unhashable

            if not scored:
                break

            # Softmax sampling
            keys, values = zip(*scored.items())
            weights = F.softmax(torch.tensor(values), dim=0).numpy()
            index = np.random.choice(len(keys), p=weights)
            picked = list(keys[index])
            picked_variance = values[index]

            if picked_variance >= self.epsilon:
                task = {
                    "s": picked,
                    "sg": self.goal_state,
                    "sigma2_e": picked_variance,
                }
                return task

            came_from = here
            here = picked

        return None

    def ida_star(self, start, goal, heuristic, tmax, start_time):
        """Run IDA* from ``start`` to ``goal`` and count generated nodes.

        Args:
            start: Initial board.
            goal: Target board.
            heuristic: Callable mapping a board to an estimated cost-to-goal.
            tmax: Maximum number of seconds allowed since ``start_time``.
            start_time: ``time.time()`` value the limit is measured from.

        Returns:
            A pair ``(path, total_nodes)``. ``path`` is the list of boards
            from ``start`` to ``goal`` inclusive, or ``None`` if the search
            space is exhausted. ``total_nodes`` is the number of successors
            generated across all iterations, including ones skipped because
            they were already on the current path.

        Raises:
            TimeoutError: If the time limit is exceeded during the search.
        """
        route = [start]
        generated = 0  # running count of generated successors

        def dfs(depth_cost, limit):
            # Returns True on reaching the goal, otherwise the smallest
            # f-value above ``limit`` seen in this subtree.
            nonlocal generated
            if time.time() - start_time > tmax:
                raise TimeoutError()

            current = route[-1]
            f_value = depth_cost + heuristic(current)
            if f_value > limit:
                return f_value
            if current == goal:
                return True

            successors = get_valid_moves(current)
            generated += len(successors)

            smallest_excess = float('inf')
            for successor in successors:
                if successor in route:
                    continue  # avoid cycles along the current path
                route.append(successor)
                outcome = dfs(depth_cost + 1, limit)
                if outcome is True:
                    # Keep the successor on the route: it belongs to the plan.
                    return True
                if outcome < smallest_excess:
                    smallest_excess = outcome
                route.pop()
            return smallest_excess

        limit = heuristic(start)
        while True:
            outcome = dfs(0, limit)
            if outcome is True:
                return route, generated
            if outcome == float('inf'):
                return None, generated
            limit = outcome

    def uncertainty_aware_heuristic(self, state):
        """Computes h(s) = max(h(α, ŷ(s), σ_t(s)), 0).

        ŷ and the aleatoric variance come from the FFNN. The aleatoric
        variance is used when ŷ is below ``self.yq``; otherwise
        ``self.epsilon`` is used as the variance.

        Args:
            state: Board to evaluate.

        Returns:
            The heuristic value, clipped below at 0.

        Note:
            Requires ``self.yq`` to be set and puts the FFNN in eval mode.
        """
        encoded = encode_15puzzle_state(state)
        batch = torch.tensor(encoded, dtype=torch.float32).unsqueeze(0)

        # FFNN prediction without dropout or gradient tracking.
        self.nnFFNN.eval()
        with torch.no_grad():
            mean, logvar = self.nnFFNN(batch)
            y_hat = mean.item()
            sigma2_a = torch.exp(logvar).item()

        use_aleatoric = y_hat < self.yq
        sigma_t = math.sqrt(sigma2_a if use_aleatoric else self.epsilon)

        return max(self.h(self.alpha, y_hat, sigma_t), 0)

    def compute_metrics(self):
        """Compute suboptimality and optimality metrics"""
        if not self.planner_costs:
            return 0.0, 0.0

        # Calculate suboptimality (u_i)
        suboptimalities = [
            (y / y_star) - 1
            for y, y_star in zip(self.planner_costs, self.optimal_costs)
            if y_star > 0  # Avoid division by zero
        ]
        avg_suboptimality = (
            sum(suboptimalities) / len(suboptimalities) if suboptimalities else 0.0
        )

        # Calculate optimality rate (% tasks solved optimally)
        optimality_rate = (
            (self.optimality_counts / len(self.planner_costs)) * 100
            if self.planner_costs
            else 0.0
        )

        return avg_suboptimality, optimality_rate

    def train_ffnn(self):
        """Trains FFNN on entire memory buffer"""
        if len(self.memoryBuffer) < self.MiniBatchSize:
            return

        optimizer = optim.Adam(self.nnFFNN.parameters())
        criterion = nn.GaussianNLLLoss()

        # Convert memory buffer to tensors
        x_data = torch.stack(
            [torch.tensor(x, dtype=torch.float32) for x, _ in self.memoryBuffer]
        )
        y_data = torch.tensor(
            [y for _, y in self.memoryBuffer], dtype=torch.float32
        ).unsqueeze(1)

        self.nnFFNN.train()
        for _ in range(self.TrainIter):
            # Shuffle and batch the data
            permutation = torch.randperm(len(x_data))
            for i in range(0, len(x_data), self.MiniBatchSize):
                indices = permutation[i : i + self.MiniBatchSize]
                x_batch, y_batch = x_data[indices], y_data[indices]

                optimizer.zero_grad()
                mean, logvar = self.nnFFNN(x_batch)
                loss = criterion(mean, y_batch, torch.exp(logvar))
                loss.backward()
                optimizer.step()

    def train_wunn(self):
        """Train the WUNN with uncertainty-weighted sampling and early stopping.

        Before training, the epistemic uncertainty of every buffer entry is
        estimated once and turned into a sampling weight: entries at or above
        ``kappa * epsilon`` get ``exp(sqrt(sigma2_e))``, all others ``exp(-1)``.
        These weights are not refreshed while training.

        Every 10th iteration (including the first) up to 100 entries from the
        front of the buffer are re-evaluated; if none reaches the threshold,
        training stops. Otherwise a mini-batch is drawn with replacement and
        one optimizer step is taken per sampled record on
        ``beta * KL + MSE(mean prediction, target)``.

        Returns:
            bool: True if training stopped early because all checked entries
            were below the uncertainty threshold; False otherwise, including
            when the buffer holds fewer than ``MiniBatchSize`` records.
        """
        if len(self.memoryBuffer) < self.MiniBatchSize:
            return False

        self.nnWUNN.train()
        optimizer = optim.Adam(self.nnWUNN.parameters(), lr=0.01)
        threshold = self.kappa * self.epsilon

        # Snapshot the buffer once: it is not modified in this method, and a
        # list gives O(1) positional access for the sampled indices and a
        # cheap slice for the early-stopping subset.
        records = list(self.memoryBuffer)

        # Precompute epistemic uncertainties (approximate σ²_e) for the buffer.
        uncertainties = [self.nnWUNN.predict_sigma_e(x, K=10) for x, _ in records]

        # Sampling weights: exp(σ_e) for uncertain entries, exp(-C) with C = 1
        # for the rest.
        low_weight = math.exp(-1)
        weights = [
            math.exp(math.sqrt(sigma2_e)) if sigma2_e >= threshold else low_weight
            for sigma2_e in uncertainties
        ]

        population = range(len(records))
        batch_size = min(self.MiniBatchSize, len(records))

        for step in range(self.MaxTrainIter):
            # Early-stopping check on a subset of the buffer; `any` stops at
            # the first entry that is still too uncertain.
            if step % 10 == 0 and not any(
                self.nnWUNN.predict_sigma_e(x, K=10) >= threshold
                for x, _ in records[:100]
            ):
                return True

            # Prioritized sampling (with replacement) according to the weights.
            batch_indices = random.choices(population, weights=weights, k=batch_size)

            for idx in batch_indices:
                x, y = records[idx]
                x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(0)
                y_tensor = torch.tensor([y], dtype=torch.float32)
                preds = torch.stack(
                    [self.nnWUNN.forward_single(x_tensor) for _ in range(self.nnWUNN.S)]
                )
                # preds.mean() is 0-d; reshape it to match the (1,) target so
                # mse_loss does not broadcast (and warn about) mismatched sizes.
                # The loss value is unchanged.
                log_likelihood = -F.mse_loss(preds.mean().reshape(1), y_tensor)
                kl_div = self.nnWUNN.fc1.kl_divergence() + self.nnWUNN.fc2.kl_divergence()
                loss = self.beta * kl_div - log_likelihood
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        return False

    def run(self):
        """Run the outer learning loop for ``NumIter`` iterations.

        Each iteration:
          1. refreshes ``yq`` as the ``q``-quantile of the targets in the
             memory buffer (it stays at ``-inf`` while the buffer is empty);
          2. generates ``NumTasksPerIter`` tasks and tries to solve each with
             IDA* under the uncertainty-aware heuristic and time limit ``tmax``;
          3. for every solved task, records plan statistics and pushes the
             plan's states (all but the last) onto the front of the buffer;
          4. lowers α by ``delta`` (never below 0.5) when fewer than
             ``NumTasksPerIterThresh`` tasks were solved;
          5. retrains both networks and multiplies β by γ unconditionally.

        NOTE(review): both the stored training targets and the reported
        "optimal" cost are Manhattan distances to the task goal, not costs
        derived from the plan itself — confirm this is intended.
        """
        self.yq = -np.inf

        for n in range(1, self.NumIter + 1):
            print(f"\n=== Iteration {n}/{self.NumIter} ===")
            print(f"Current β: {self.beta:.6f}")  # Track β precisely

            # Refresh the target quantile from whatever the buffer holds now.
            if self.memoryBuffer:
                self.yq = np.quantile(
                    [target for _, target in self.memoryBuffer], self.q
                )
            print(f"Current yq (q={self.q}): {self.yq:.2f}, α: {self.alpha:.3f}")

            # Task generation and planning phase.
            numSolved = 0
            for i in range(self.NumTasksPerIter):
                task = self.generate_task()
                if not task:
                    print("No task generated (low uncertainty)")
                    continue

                try:
                    started_at = time.time()
                    plan = self.ida_star(
                        task["s"],
                        task["sg"],
                        self.uncertainty_aware_heuristic,
                        self.tmax,
                        started_at,
                    )
                    if not plan:
                        # Planner returned nothing usable; move on.
                        continue

                    numSolved += 1
                    plan_cost = len(plan) - 1
                    optimal_cost = manhattan_distance(task["s"], task["sg"])
                    self.planner_costs.append(plan_cost)
                    self.optimal_costs.append(optimal_cost)
                    if plan_cost == optimal_cost:
                        self.optimality_counts += 1
                    print(f"✓ Task {i+1}: cost={plan_cost}, optimal={optimal_cost}")

                    # Walk the plan backwards, skipping its final state, and
                    # push each (features, target) pair onto the buffer front.
                    for state in reversed(plan[:-1]):
                        self.memoryBuffer.appendleft(
                            (
                                encode_15puzzle_state(state),
                                manhattan_distance(state, task["sg"]),
                            )
                        )
                except TimeoutError:
                    print(f"⏳ Task {i+1} timed out")

            # α only shrinks when too few tasks were solved; 0.5 is its floor.
            if numSolved < self.NumTasksPerIterThresh:
                self.alpha = max(self.alpha - self.delta, 0.5)
                print(f"Reduced α to {self.alpha:.3f} (solved {numSolved} tasks)")

            # Retrain both networks; the WUNN early-stop flag is not used.
            print("Training models...")
            self.train_ffnn()
            self.train_wunn()

            # β decays every iteration regardless of the training outcome.
            self.beta *= self.gamma
            print(f"Decayed β to {self.beta:.6f} (γ={self.gamma:.6f})")

            # Per-iteration summary.
            avg_subopt, opt_rate = self.compute_metrics()
            print(f"Iteration {n} results:")
            print(f"  Solved: {numSolved}/{self.NumTasksPerIter}")
            print(f"  Suboptimality: {avg_subopt:.3f}")
            print(f"  Optimality Rate: {opt_rate:.1f}%")


# --- Example Usage ---
if __name__ == "__main__":
    # Goal configuration 0..15 in order; the encoder's output length fixes
    # the network input size.
    goal_state = list(range(16))
    input_dim = len(encode_15puzzle_state(goal_state))

    # Hyperparameters handed to the learner as a plain string-keyed dict.
    # No "gamma" entry is supplied here; per the original author's notes it is
    # derived inside LearnHeuristicPrac.__init__ so that β decays from beta0
    # towards 0.00001 over NumIter iterations (not visible in this block —
    # verify against __init__).
    params = dict(
        hidden_dim=20,
        dropout_rate=0.025,
        alpha0=0.99,
        beta0=0.05,
        epsilon=1.0,
        delta=0.05,
        kappa=0.64,
        q=0.95,
        K=100,
        MaxSteps=1000,
        mu0=0.0,
        sigma2_0=10.0,
        NumIter=50,
        NumTasksPerIter=10,
        NumTasksPerIterThresh=6,
        TrainIter=1000,
        MaxTrainIter=5000,
        MiniBatchSize=100,
        tmax=60,
        MemoryBufferMaxRecords=25000,
    )

    learner = LearnHeuristicPrac(input_dim, goal_state, params)
    learner.run()

    # Report the metrics accumulated over the whole run.
    avg_subopt, opt_rate = learner.compute_metrics()
    print("\n=== Final Metrics ===")
    print(f"Average Suboptimality: {avg_subopt:.3f}")
    print(f"Optimality Rate: {opt_rate:.1f}%")
