# A Study of Hybrid Methods for Solving the Symmetric Traveling Salesman Problem to Optimize Parameter Selection for a Technological System
Student, group ВМ-223: Баринов Даниил Сергеевич

## Libraries

In [None]:
# Standard library
import gzip
import logging
import math
import os
import shutil
import signal
import time
import urllib.request  # the submodule must be imported explicitly for urlretrieve
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Numerical and scientific stack
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform
from sklearn.utils import shuffle
from sklearn.utils.class_weight import compute_class_weight

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Logging

In [None]:
def configure_logger(level='INFO', logfile=None):
    """
    Configures the logger with the specified log level and log file.

    Args:
        level (str): Log level ('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL').
        logfile (str): Path to the log file.

    Returns:
        None
    """
    logger = logging.getLogger("console")
    logger.setLevel(level)

    if not logger.handlers:
        if not logfile:
            console_handler = logging.StreamHandler()
            console_handler.setLevel(level)
            logger.addHandler(console_handler)
        else:
            file_handler = logging.FileHandler(logfile)
            file_handler.setLevel(level)
            logger.addHandler(file_handler)


def log(message, level='INFO'):
    """
    Logs a message with the specified log level.

    Args:
        message (str): Message to be logged.
        level (str): Log level ('DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL').

    Returns:
        None
    """
    logger = logging.getLogger("console")

    match level:
        case 'DEBUG':
            logger.debug(message)
        case 'INFO':
            logger.info(message)
        case 'WARNING':
            logger.warning(message)
        case 'ERROR':
            logger.error(message)
        case 'CRITICAL':
            logger.critical(message)
        case _:
            logger.debug(message)
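
The level dispatch in `log` can also be written without a `match` statement, which may help on Python versions older than 3.10. The sketch below is a hypothetical variant (`log_compact` is not part of the notebook) that looks up the logger method by name and, like the default case above, falls back to `debug` for unknown levels. It writes to an in-memory buffer only so that the output can be inspected.

```python
import io
import logging

_LEVELS = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}


def log_compact(message, level="INFO"):
    """Hypothetical compact variant of log(): dispatch by method name."""
    logger = logging.getLogger("console")
    # Unknown levels fall back to debug, mirroring the default case of log()
    method_name = level.lower() if level in _LEVELS else "debug"
    getattr(logger, method_name)(message)


# Attach a handler that writes to an in-memory buffer (for demonstration only)
buffer = io.StringIO()
demo_logger = logging.getLogger("console")
demo_logger.setLevel("INFO")
demo_logger.addHandler(logging.StreamHandler(buffer))

log_compact("solver started")               # emitted at INFO level
log_compact("hidden detail", level="DEBUG")  # filtered out by the INFO threshold
captured = buffer.getvalue()
```

With the logger set to `INFO`, only the first message reaches the handler.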

## Benchmark configuration

Timeout handler

In [None]:
class TimeoutException(Exception):   # Custom exception class
    pass

def handle_alarm(signal_number, frame):
    raise TimeoutException

# SIGALRM is only available on Unix-like systems
signal.signal(signal.SIGALRM, handle_alarm)

signal_time = 300

times_info = {}
shortest_paths = {}
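
One way to apply such an alarm handler is to wrap each solver call in a time limit. The sketch below is a minimal illustration under stated assumptions, not the notebook's own usage: `time_limit` and `run_with_limit` are hypothetical helpers, and `signal.setitimer` is used instead of `signal.alarm` so that fractional seconds work. It runs only on Unix-like systems and only in the main thread.

```python
import signal
import time
from contextlib import contextmanager


class TimeoutException(Exception):
    pass


def handle_alarm(signal_number, frame):
    raise TimeoutException


@contextmanager
def time_limit(seconds):
    """Hypothetical helper: raise TimeoutException if the body runs too long."""
    previous = signal.signal(signal.SIGALRM, handle_alarm)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)   # cancel any pending alarm
        signal.signal(signal.SIGALRM, previous)   # restore the old handler


def run_with_limit(func, seconds):
    """Return (result, timed_out) for a zero-argument callable."""
    try:
        with time_limit(seconds):
            return func(), False
    except TimeoutException:
        return None, True


result_fast, timed_out_fast = run_with_limit(lambda: sum(range(1000)), 1.0)
result_slow, timed_out_slow = run_with_limit(lambda: time.sleep(2), 0.2)
```

The fast call finishes within its budget, while the slow call is interrupted after roughly 0.2 seconds and reported as timed out.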

TSP benchmark class for TSPLIB instances and hybrid methods

In [None]:
class TSPBenchmark:
    def __init__(self, data_dir: str = "tsp_data"):
        """
        Initialize the TSP benchmark environment.

        Args:
            data_dir: Directory to store/load the TSP datasets
        """
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(exist_ok=True)
        self.instances = {}  # Maps instance name to metadata
        self.optimal_tours = {}  # Maps instance name to optimal tour

    def download_instance(self, instance_name: str, url: str) -> bool:
        """
        Download a TSPLIB instance if not already present.
        Handles both plain files and gzipped files (.gz).

        Args:
            instance_name: Name of the instance (e.g., 'a280')
            url: Base URL to download from (will be modified for .gz files)

        Returns:
            True if successful, False otherwise
        """
        instance_dir = self.data_dir / instance_name
        instance_dir.mkdir(exist_ok=True)

        tsp_file = instance_dir / f"{instance_name}.tsp"
        opt_file = instance_dir / f"{instance_name}.opt.tour"

        # Handle gzipped files
        gz_url = url + ".gz"  # Append .gz to the URL

        # Download .tsp file if not exists
        if not tsp_file.exists():
            try:
                print(f"Downloading {instance_name}.tsp.gz...")
                # Download to a temporary file
                temp_gz_file = instance_dir / f"{instance_name}.tsp.gz"
                urllib.request.urlretrieve(gz_url, temp_gz_file)

                # Extract the gzipped file
                with gzip.open(temp_gz_file, 'rb') as f_in:
                    with open(tsp_file, 'wb') as f_out:
                        shutil.copyfileobj(f_in, f_out)

                # Remove the temporary file
                temp_gz_file.unlink()
                print(f"Successfully downloaded and extracted {instance_name}.tsp")
            except Exception as e:
                print(f"Error downloading {gz_url}: {e}")

                # Try without .gz extension as fallback
                try:
                    print(f"Trying without .gz extension...")
                    urllib.request.urlretrieve(url, tsp_file)
                    print(f"Successfully downloaded {instance_name}.tsp")
                except Exception as e2:
                    print(f"Error downloading {url}: {e2}")
                    return False

        # Try to download optimal tour if available
        opt_url = url.replace(".tsp", ".opt.tour") + ".gz"
        if not opt_file.exists():
            try:
                print(f"Downloading {instance_name}.opt.tour.gz...")
                # Download to a temporary file
                temp_gz_file = instance_dir / f"{instance_name}.opt.tour.gz"
                urllib.request.urlretrieve(opt_url, temp_gz_file)

                # Extract the gzipped file
                with gzip.open(temp_gz_file, 'rb') as f_in:
                    with open(opt_file, 'wb') as f_out:
                        shutil.copyfileobj(f_in, f_out)

                # Remove the temporary file
                temp_gz_file.unlink()
                print(f"Successfully downloaded and extracted optimal tour for {instance_name}")
            except Exception as e:
                print(f"Optimal tour not available for {instance_name}: {e}")

                # Try without .gz extension as fallback
                try:
                    print(f"Trying without .gz extension...")
                    opt_url_plain = url.replace(".tsp", ".opt.tour")
                    urllib.request.urlretrieve(opt_url_plain, opt_file)
                    print(f"Successfully downloaded optimal tour for {instance_name}")
                except Exception as e2:
                    print(f"Optimal tour not available (tried plain file): {e2}")
                    # This is not a fatal error, as some instances might not have published optimal tours

        return True

    def download_tsplib_set(self, size_range: Tuple[int, int] = (100, 1000)) -> None:
        """
        Download a fixed set of popular TSPLIB instances.

        Args:
            size_range: Tuple of (min_cities, max_cities). Currently unused: the instance list below is hard-coded.
        """
        # TSPLIB index URL
        tsplib_index = "http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/tsp/"

        # Download a few popular instances
        instances = {
            "berlin52": "http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/tsp/berlin52.tsp",
            "eil101": "http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/tsp/eil101.tsp",
            "ch130": "http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/tsp/ch130.tsp",
            "ch150": "http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/tsp/ch150.tsp",
            # "brg180": "http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/tsp/brg180.tsp",
            "a280": "http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/tsp/a280.tsp",
            "pcb442": "http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/tsp/pcb442.tsp",
            # "pr1002": "http://comopt.ifi.uni-heidelberg.de/software/TSPLIB95/tsp/pr1002.tsp"
        }

        for name, url in instances.items():
            self.download_instance(name, url)

    def load_instance(self, instance_name: str) -> Dict[str, Any]:
        """
        Load a TSP instance and parse its metadata.

        Args:
            instance_name: Name of the instance to load

        Returns:
            Dictionary with instance metadata and distance matrix
        """
        tsp_file = self.data_dir / instance_name / f"{instance_name}.tsp"

        if not tsp_file.exists():
            raise FileNotFoundError(f"Instance file {tsp_file} not found")

        metadata = {
            "name": instance_name,
            "coords": [],
            "dimension": 0,
            "edge_weight_type": "",
            "comment": ""
        }

        # Parse the .tsp file
        reading_coords = False
        with open(tsp_file, 'r') as f:
            for line in f:
                line = line.strip()

                if line == "EOF":
                    break

                if reading_coords:
                    if line:
                        parts = line.split()
                        # Skip the index (first part) and parse the coordinates
                        coords = [float(p) for p in parts[1:3]]
                        metadata["coords"].append(coords)
                    continue

                if ":" in line:
                    key, value = [p.strip() for p in line.split(":", 1)]
                    if key == "DIMENSION":
                        metadata["dimension"] = int(value)
                    elif key == "EDGE_WEIGHT_TYPE":
                        metadata["edge_weight_type"] = value
                    elif key == "COMMENT":
                        metadata["comment"] = value

                if line == "NODE_COORD_SECTION":
                    reading_coords = True

        # Convert coordinates to a NumPy array
        metadata["coords"] = np.array(metadata["coords"])

        # Calculate distance matrix
        n = metadata["dimension"]
        distances = np.zeros((n, n))
        if metadata["edge_weight_type"] == "EUC_2D":
            # Euclidean distance rounded as TSPLIB's nint(x) = int(x + 0.5);
            # Python's round() rounds halves to even and can differ on exact .5 values
            for i in range(n):
                for j in range(n):
                    if i != j:
                        dx = metadata["coords"][i][0] - metadata["coords"][j][0]
                        dy = metadata["coords"][i][1] - metadata["coords"][j][1]
                        distances[i][j] = int(math.sqrt(dx*dx + dy*dy) + 0.5)
        elif metadata["edge_weight_type"] == "GEO":
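            # Geographical distance (latitude/longitude) is not implemented;
            # fail loudly instead of silently returning an all-zero matrix
            raise NotImplementedError(
                f"GEO edge weights are not supported (instance {instance_name})"
            )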
        else:
            # Default to Euclidean distance (not rounded)
            for i in range(n):
                for j in range(n):
                    if i != j:
                        distances[i][j] = np.linalg.norm(metadata["coords"][i] - metadata["coords"][j])

        metadata["distances"] = distances

        # Try to load optimal tour if available
        opt_file = self.data_dir / instance_name / f"{instance_name}.opt.tour"
        if opt_file.exists():
            optimal_tour = self.load_optimal_tour(opt_file)
            metadata["optimal_tour"] = optimal_tour
            metadata["optimal_length"] = self.calculate_tour_length(
                optimal_tour, distances
            )
            print(f"Loaded optimal tour with length {metadata['optimal_length']}")
        else:
            print(f"No optimal tour available for {instance_name}")

        self.instances[instance_name] = metadata
        return metadata

    def load_optimal_tour(self, tour_file: Path) -> List[int]:
        """
        Load an optimal tour from a .tour file.

        Args:
            tour_file: Path to the .tour file

        Returns:
            List of city indices forming the optimal tour
        """
        tour = []
        reading_tour = False

        with open(tour_file, 'r') as f:
            for line in f:
                line = line.strip()

                if line == "EOF":
                    break

                if reading_tour:
                    if line and line.isdigit():
                        # TSPLIB tour files use 1-based indexing, convert to 0-based
                        tour.append(int(line) - 1)
                    continue

                if line == "TOUR_SECTION":
                    reading_tour = True

        return tour

    def calculate_tour_length(self, tour: List[int], distances: np.ndarray) -> float:
        """
        Calculate the length of a tour.

        Args:
            tour: List of city indices
            distances: Distance matrix

        Returns:
            Total tour length
        """
        length = 0
        for i in range(len(tour)):
            length += distances[tour[i]][tour[(i + 1) % len(tour)]]
        return length

    def benchmark_solver(self,
                         instance_name: str,
                         solver_class,
                         solver_name = None,
                         solver_params: Dict[str, Any] = None,
                         runs: int = 5) -> Dict[str, Any]:
        """
        Benchmark a TSP solver on a specific instance.

        Args:
            instance_name: Name of the instance to benchmark on
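            solver_class: Class of the solver to benchmark (can be None when the solver
                is supplied through solver_params, as in the GNN branches)
            solver_name: Name used to select the solve call (e.g., "TSP_2opt_SA_naive", "GNN_v1_BS")
            solver_params: Parameters to pass to the solver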
            runs: Number of runs to perform

        Returns:
            Dictionary with benchmark results
        """
        if instance_name not in self.instances:
            self.load_instance(instance_name)

        instance = self.instances[instance_name]

        if solver_params is None:
            solver_params = {}

        results = {
            "instance": instance_name,
            "dimension": instance["dimension"],
            "runs": [],
            "best_length": float('inf'),
            "worst_length": 0,
            "mean_length": 0,
            "mean_time": 0
        }

        total_length = 0
        total_time = 0

        for run in range(runs):
            print(f"Run {run+1}/{runs}...")

            # Initialize the solver
            if solver_class:
                solver = solver_class(instance["distances"])
            

            # Call the appropriate solve method based on the solver name and time it
            if solver_name == "TSP_2opt_SA_naive":
                start_time = time.time()
                tour, length = solver.solve_hybrid(**solver_params)
                solve_time = time.time() - start_time
            elif solver_name in ("GNN_v1_BS", "GNN_v2_BS"):
                # Pure GNN with beam search; both model versions share the same call
                gnn_model = solver_params["model"]
                config = solver_params["config"]
                beam_size_gnn_pure = solver_params["beam_size_gnn_pure"]
                dtypeFloat = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
                dtypeLong = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor
                coords = np.array(instance["coords"])
                start_time = time.time()
                tour, length, _ = TSPComparison.pure_gnn_beam_search_solve(
                    coords, gnn_model, config, beam_size_gnn_pure, dtypeFloat, dtypeLong
                )
                solve_time = time.time() - start_time
            elif solver_name == "GNN_v1_BC":
                gnn_solver_v1 = solver_params["gnn_solver"]
                coords = np.array(instance["coords"])
                start_time = time.time()
                tour, length, _ = gnn_solver_v1.solve_tsp(coords)
                solve_time = time.time() - start_time
            else:
                # Generic solve method; also covers TSP_LKH_Hybrid_naive, TSP_2opt_SA_outer,
                # TSP_2opt_3opt_SA and TSP_LKH_Hybrid_optimized
                start_time = time.time()
                tour, length = solver.solve(**solver_params)
                solve_time = time.time() - start_time


            # Record results
            run_result = {
                "tour": tour,
                "length": length,
                "time": solve_time
            }

            results["runs"].append(run_result)
            total_length += length
            total_time += solve_time

            # Update best and worst
            if length < results["best_length"]:
                results["best_length"] = length
                results["best_tour"] = tour
            if length > results["worst_length"]:
                results["worst_length"] = length

        # Calculate means
        results["mean_length"] = total_length / runs
        results["mean_time"] = total_time / runs

        # Calculate accuracy if optimal tour is available
        if "optimal_length" in instance:
            optimal_length = instance["optimal_length"]
            results["optimal_length"] = optimal_length
            results["best_gap"] = (results["best_length"] - optimal_length) / optimal_length * 100
            results["mean_gap"] = (results["mean_length"] - optimal_length) / optimal_length * 100

            print(f"Optimal tour length: {optimal_length}")
            print(f"Best tour length: {results['best_length']} (gap: {results['best_gap']:.2f}%)")
            print(f"Mean tour length: {results['mean_length']} (gap: {results['mean_gap']:.2f}%)")
        else:
            print(f"Best tour length: {results['best_length']}")
            print(f"Mean tour length: {results['mean_length']}")

        print(f"Mean solution time: {results['mean_time']:.2f} seconds")

        return results

    def benchmark_all(self,
                      solvers: List[Tuple[str, Any, Dict[str, Any]]],
                      instances: List[str] = None,
                      runs: int = 3) -> Dict[str, Dict[str, Any]]:
        """
        Benchmark multiple solvers on multiple instances.

        Args:
            solvers: List of (solver_name, solver_class, solver_params) tuples
            instances: List of instance names to benchmark on (if None, use all loaded instances)
            runs: Number of runs per solver per instance

        Returns:
            Dictionary mapping (instance_name, solver_name) to benchmark results
        """
        if instances is None:
            # Use all instances that have been loaded
            instances = list(self.instances.keys())

            # If no instances are loaded, download and load some standard ones
            if not instances:
                self.download_tsplib_set()
                instances = ["berlin52", "a280", "pcb442"]
                for instance in instances:
                    self.load_instance(instance)

        results = {}

        for instance_name in instances:
            print(f"\n{'='*50}")
            print(f"Benchmarking on {instance_name}")
            print(f"{'='*50}")

            for solver_name, solver_class, solver_params in solvers:
                print(f"\n{'-'*40}")
                print(f"Using solver: {solver_name}")
                print(f"{'-'*40}")

                key = (instance_name, solver_name)
                results[key] = self.benchmark_solver(
                    instance_name,
                    solver_class,
                    solver_name,
                    solver_params,
                    runs
                )

                # Add solver information
                results[key]["solver"] = solver_name
                results[key]["params"] = solver_params

        # Print summary
        self.print_benchmark_summary(results)

        return results

    def print_benchmark_summary(self, results: Dict[Tuple[str, str], Dict[str, Any]]) -> None:
        """
        Print a summary of benchmark results.

        Args:
            results: Dictionary mapping (instance_name, solver_name) to benchmark results
        """
        print("\n\n")
        print("="*80)
        print("BENCHMARK SUMMARY")
        print("="*80)

        # Group by instance
        by_instance = {}
        for (instance, solver), result in results.items():
            if instance not in by_instance:
                by_instance[instance] = []
            by_instance[instance].append((solver, result))

        # Print results for each instance
        for instance, solver_results in by_instance.items():
            print(f"\nInstance: {instance} ({self.instances[instance]['dimension']} cities)")
            print("-" * 80)
            print(f"{'Solver':<20} {'Best Length':<15} {'Mean Length':<15} {'Mean Time (s)':<15} {'Gap (%)':<10}")
            print("-" * 80)

            optimal_length = None
            if "optimal_length" in self.instances[instance]:
                optimal_length = self.instances[instance]["optimal_length"]

            for solver, result in sorted(solver_results, key=lambda x: x[1]["best_length"]):
                if optimal_length is not None:
                    gap = (result["best_length"] - optimal_length) / optimal_length * 100
                    gap_str = f"{gap:.2f}%"
                else:
                    gap_str = "N/A"

                print(f"{solver:<20} {result['best_length']:<15.2f} {result['mean_length']:<15.2f} {result['mean_time']:<15.2f} {gap_str:<10}")

            if optimal_length is not None:
                print(f"\nOptimal tour length: {optimal_length}")

    def visualize_tour(self, instance_name: str, tour: List[int], title: str = None) -> None:
        """
        Visualize a tour for a given instance.

        Args:
            instance_name: Name of the instance
            tour: List of city indices representing the tour
            title: Title for the plot
        """
        try:
            import matplotlib.pyplot as plt

            if instance_name not in self.instances:
                self.load_instance(instance_name)

            instance = self.instances[instance_name]
            coords = instance["coords"]

            # Create a plot
            plt.figure(figsize=(10, 8))

            # Plot cities
            plt.scatter(coords[:, 0], coords[:, 1], c='blue', s=20)

            # Plot tour
            for i in range(len(tour)):
                city1 = tour[i]
                city2 = tour[(i + 1) % len(tour)]
                plt.plot([coords[city1, 0], coords[city2, 0]],
                         [coords[city1, 1], coords[city2, 1]],
                         'r-', alpha=0.7)

            if title:
                plt.title(title)
            else:
                plt.title(f"Tour for {instance_name}")

            plt.tight_layout()

            # Save the plot to a file
            output_dir = self.data_dir / "plots"
            output_dir.mkdir(exist_ok=True)
            plt.savefig(output_dir / f"{instance_name}_{int(time.time())}.png")

            plt.show()

        except ImportError:
            print("Matplotlib is required for visualization.")
            print("Install it with: pip install matplotlib")
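
The generic branch of `benchmark_solver` expects a solver class whose constructor takes the distance matrix and whose `solve(**solver_params)` returns a `(tour, length)` pair. The sketch below illustrates that interface with a nearest-neighbor baseline. Both `NearestNeighborSolver` and `euc_2d_matrix` are hypothetical names introduced for illustration and are not among the solvers studied in this notebook. The distance helper follows the TSPLIB `EUC_2D` convention of rounding with `int(x + 0.5)`.

```python
import numpy as np


def euc_2d_matrix(coords):
    """TSPLIB EUC_2D distances: Euclidean, rounded as nint(x) = int(x + 0.5)."""
    coords = np.asarray(coords, dtype=float)
    diff = coords[:, None, :] - coords[None, :, :]
    return np.floor(np.sqrt((diff ** 2).sum(axis=-1)) + 0.5)


class NearestNeighborSolver:
    """Hypothetical baseline solver matching the interface used by benchmark_solver."""

    def __init__(self, distances):
        self.distances = np.asarray(distances, dtype=float)

    def solve(self, start=0):
        n = len(self.distances)
        unvisited = set(range(n)) - {start}
        tour = [start]
        while unvisited:
            last = tour[-1]
            # Pick the closest unvisited city; ties are broken by the lower index
            nxt = min(unvisited, key=lambda j: (self.distances[last, j], j))
            tour.append(nxt)
            unvisited.remove(nxt)
        # Closed tour length, including the edge back to the start city
        length = sum(self.distances[tour[i], tour[(i + 1) % n]] for i in range(n))
        return tour, float(length)


coords = [(0, 0), (3, 0), (3, 4), (0, 4)]  # toy 3-by-4 rectangle
solver = NearestNeighborSolver(euc_2d_matrix(coords))
tour, length = solver.solve(start=0)
```

A class like this could be passed to `benchmark_solver` with any `solver_name` that has no dedicated branch, for example `benchmark_solver("berlin52", NearestNeighborSolver, "NN_baseline", {"start": 0}, runs=1)`, in which case the generic `solver.solve(**solver_params)` call is used.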

## Visualization

In [None]:
def plot_tsp(p, x_coord, W, W_val, W_target, title="default"):
    """
    Helper function to plot TSP tours.

    Args:
        p: Matplotlib figure/subplot axis (e.g., plt.gca()).
        x_coord: Coordinates of nodes (num_nodes, 2).
        W: Edge adjacency matrix (ignored if plotting a tour path).
        W_val: Edge values (distance) matrix (used by nx.Graph).
        W_target: One-hot matrix with 1s on edges to plot (e.g., tour edges).
        title: Title of figure/subplot.

    Returns:
        p: Updated figure/subplot axis.
    """

    def _edges_to_node_pairs(W_target_matrix):
        """Helper function to convert edge matrix into pairs of adjacent nodes."""
        pairs = []
        rows, cols = np.where(W_target_matrix == 1)
        # Avoid duplicates for undirected graphs, take only (i, j) where i < j
        for r, c in zip(rows, cols):
            if r < c:
                pairs.append((r, c))
        return pairs

    # Use nx.Graph() rather than DiGraph, since a TSP tour is undirected
    G = nx.Graph(W_val) # Pass the distance matrix to build the graph
    pos = dict(zip(range(len(x_coord)), x_coord.tolist())) # Node positions

    # Extract the edge pairs from the W_target matrix
    target_pairs = _edges_to_node_pairs(W_target)

    colors = ['g'] + ['b'] * (len(x_coord) - 1) # Green for node 0, blue for the rest
    nx.draw_networkx_nodes(G, pos, node_color=colors, node_size=50, ax=p)

    # Draw only the tour edges
    nx.draw_networkx_edges(G, pos, edgelist=target_pairs, alpha=1, width=1.5, edge_color='r', ax=p)

    # Add node labels
    labels = {i: str(i) for i in range(len(x_coord))}
    nx.draw_networkx_labels(G, pos, labels=labels, font_size=8, ax=p)

    p.set_title(title)
    p.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True) # Show the axes
    p.set_xlabel("X Coordinate")
    p.set_ylabel("Y Coordinate")
    p.grid(True, linestyle='--', alpha=0.6)
    p.axis('equal') # Equal scaling on both axes

    return p


# --- Main function: plot_solution ---
def plot_solution(nodes_coord: np.ndarray,
                  tour: List[int],
                  title: str,
                  tour_length: float,
                  ax: Optional[plt.Axes] = None):
    """
    Plot a single TSP tour on the given matplotlib axis.

    Args:
        nodes_coord: Node coordinates, shape (N, 2).
        tour: List of node indices in visiting order (0-based).
        title: Title of the plot.
        tour_length: Length of the tour.
        ax: Matplotlib axis to draw on. If None, a new figure and axis are created.
    """
    n = nodes_coord.shape[0]
    if not tour:
        print(f"Warning: Cannot plot empty tour for '{title}'.")
        if ax: ax.set_title(f"{title}\n(Empty Tour)")
        return

    if len(tour) != n:
         print(f"Warning: Tour length ({len(tour)}) != num_nodes ({n}) for '{title}'. Plotting partial.")

    if ax is None:
        # No axis was passed, so create a new figure and axis
        fig, ax = plt.subplots(figsize=(8, 8))

    # Build the W_target matrix for the tour edges
    W_tour = np.zeros((n, n))
    for i in range(len(tour)):
        u = tour[i]
        v = tour[(i + 1) % len(tour)]
        if 0 <= u < n and 0 <= v < n:
            W_tour[u, v] = 1
            W_tour[v, u] = 1
        else:
            print(f"Error: Invalid node index in tour for plotting. u={u}, v={v}")
            ax.set_title(f"{title}\n(Invalid Tour)")
            return

    # Build placeholder matrices for plot_tsp
    dist_matrix_dummy = squareform(pdist(nodes_coord))
    W_dummy = np.ones((n,n))
    np.fill_diagonal(W_dummy, 0)

    # Use plot_tsp to draw on the given axis 'ax'
    plot_tsp(ax, nodes_coord, W_dummy, dist_matrix_dummy, W_tour,
             title=f"{title}\nLength: {tour_length:.4f}")

    # Highlight the start node
    start_node_idx = tour[0]
    # Use plot rather than scatter on ax for finer control over the marker
    ax.plot(nodes_coord[start_node_idx, 0], nodes_coord[start_node_idx, 1],
            marker='*', color='yellow', markersize=15, markeredgecolor='black', zorder=4, linestyle='None')

# ==============================================================================
# 2. New function `plot_comparison` (two side-by-side plots)
# ==============================================================================

def plot_comparison(nodes_coord: np.ndarray,
                    found_tour: List[int],
                    found_length: float,
                    optimal_tour: List[int],
                    optimal_length: float,
                    instance_name: str,
                    method_name: str,
                    found_color: str = 'red',
                    optimal_color: str = 'green',
                    found_linestyle: str = '-',
                    optimal_linestyle: str = '--',
                    found_alpha: float = 0.7,
                    optimal_alpha: float = 0.9,
                    start_node_color='yellow'):
    """
    Create a figure with two plots: the optimal tour on the left, the found tour on the right.
    Does not call plot_solution; everything is drawn directly.

    Args:
        nodes_coord: Node coordinates, shape (N, 2).
        found_tour: Found tour (list of 0-based indices).
        found_length: Length of the found tour.
        optimal_tour: Optimal tour (list of 0-based indices).
        optimal_length: Length of the optimal tour.
        instance_name: Name of the TSP instance.
        method_name: Name of the method that produced `found_tour`.
        found_color, optimal_color, found_linestyle, optimal_linestyle,
        found_alpha, optimal_alpha, start_node_color: Styling options for the
            two tours and the start-node marker.
    """
    n = nodes_coord.shape[0]

    # --- Create a figure with two axes ---
    fig, axes = plt.subplots(1, 2, figsize=(18, 8)) # Slightly wider figure

    # --- Helper that draws a single tour on an axis ---
    def draw_single_tour(ax, tour, length, title, tour_color, linestyle, alpha, start_node_color):
        # 1. Validate the tour
        if not tour:
            ax.set_title(f"{title}\n(Empty Tour Provided)")
            ax.text(0.5, 0.5, "No tour data", ha='center', va='center', transform=ax.transAxes)
            ax.set_xticks([]); ax.set_yticks([]) # Hide the ticks on an empty plot
            return
        if len(tour) != n:
            print(f"Warning: Tour length ({len(tour)}) != num_nodes ({n}) for '{title}'. Plotting partial.")
            # Plot whatever part of the tour is available

        # 2. Draw the nodes
        ax.scatter(nodes_coord[:, 0], nodes_coord[:, 1], c='blue', s=50, zorder=3, label='Cities')
        # Label offset: a small fraction of the Y range
        text_offset = (np.max(nodes_coord[:,1]) - np.min(nodes_coord[:,1])) * 0.02
        for i, (x, y) in enumerate(nodes_coord):
            ax.text(x, y + text_offset, str(i), fontsize=9, ha='center', va='bottom') # Shift the label upward

        # 3. Draw the tour lines
        tour_len_to_plot = len(tour)
        route_coords = nodes_coord[tour]
        # Close the tour only if it is complete
        route_coords_closed = np.vstack([route_coords, route_coords[0]]) if len(tour) == n else route_coords
        ax.plot(route_coords_closed[:, 0], route_coords_closed[:, 1],
                 color=tour_color, linestyle=linestyle, alpha=alpha,
                 lw=1.5, zorder=1, label=f'Length: {length:.4f}') # Length shown with 4 decimal places

        # 4. Highlight the start node
        start_node_idx = tour[0]
        if 0 <= start_node_idx < n:
            ax.plot(nodes_coord[start_node_idx, 0], nodes_coord[start_node_idx, 1],
                    marker='*', color=start_node_color, markersize=15,
                    markeredgecolor='black', zorder=4, linestyle='None',
                    label=f'Start Node ({start_node_idx})')  # Show the actual start index in the legend

        # 5. Axis settings
        ax.set_title(title)
        ax.set_xlabel('X Coordinate')
        ax.set_ylabel('Y Coordinate')
        ax.grid(True, linestyle='--', alpha=0.6)
        # Set the axis limits after drawing,
        # leaving a small margin around the outermost points
        margin_x = (np.max(nodes_coord[:, 0]) - np.min(nodes_coord[:, 0])) * 0.1
        margin_y = (np.max(nodes_coord[:, 1]) - np.min(nodes_coord[:, 1])) * 0.1
        ax.set_xlim(np.min(nodes_coord[:, 0]) - margin_x, np.max(nodes_coord[:, 0]) + margin_x)
        ax.set_ylim(np.min(nodes_coord[:, 1]) - margin_y, np.max(nodes_coord[:, 1]) + margin_y)
        ax.set_aspect('equal', adjustable='box')  # set_aspect is used instead of axis('equal')
        ax.legend(loc='best', fontsize='small')

    # --- Left axis: optimal tour ---
    if optimal_tour:
        draw_single_tour(ax=axes[0], tour=optimal_tour, length=optimal_length,
                         title=f"Optimal Solution: {instance_name}",
                         tour_color=optimal_color, linestyle=optimal_linestyle,
                         alpha=optimal_alpha, start_node_color=start_node_color)
    else:
         axes[0].set_title(f"Optimal Solution: {instance_name}\n(Not Available)")
         axes[0].text(0.5, 0.5, "Optimal tour data missing", ha='center', va='center', transform=axes[0].transAxes)
         axes[0].set_xticks([]); axes[0].set_yticks([])

    # --- Right axis: found tour ---
    if found_tour:
        draw_single_tour(ax=axes[1], tour=found_tour, length=found_length,
                         title=f"Found ({method_name}): {instance_name}",
                         tour_color=found_color, linestyle=found_linestyle,
                         alpha=found_alpha, start_node_color=start_node_color)
    else:
         axes[1].set_title(f"Found ({method_name}): {instance_name}\n(Not Available)")
         axes[1].text(0.5, 0.5, "Found tour data missing", ha='center', va='center', transform=axes[1].transAxes)
         axes[1].set_xticks([]); axes[1].set_yticks([])

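    # Figure-level title with the optimality gap
    gap = float('inf')
    gap_str = "N/A"
    # Guard against missing lengths before computing the relative gap
    if (optimal_length is not None and optimal_length > 0
            and found_length is not None and found_length != float('inf')):
        gap = (found_length - optimal_length) / optimal_length * 100
        gap_str = f"{gap:.2f}%"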

    fig.suptitle(f"TSP Solution Comparison: {instance_name}\nMethod: {method_name} | Gap: {gap_str}", fontsize=16)

    plt.tight_layout(rect=[0, 0.03, 1, 0.93])
    plt.show()
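
The gap shown in the figure title is the relative excess of the found tour length over the optimal one, in percent. A minimal sketch of that computation as a standalone helper follows; the function name `optimality_gap` is hypothetical and not part of the notebook.

```python
import math

def optimality_gap(found_length, optimal_length):
    """Return the relative gap in percent, or None when it is undefined."""
    if optimal_length is None or optimal_length <= 0:
        return None
    if found_length is None or math.isinf(found_length):
        return None
    return (found_length - optimal_length) / optimal_length * 100.0

# Illustrative lengths, not results from the notebook.
gap = optimality_gap(found_length=4.2, optimal_length=4.0)
print("N/A" if gap is None else f"{gap:.2f}%")
```

Returning `None` for the undefined cases keeps the "N/A" decision in one place instead of spreading it across the plotting code.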

# Graph Neural Networks for Solving the Traveling Salesman Problem

The graph neural network architecture implemented here (the baseline version, without Branch&Cut integration) is due to the authors of the following work:

Isik S., Atkin M. Tackling the Traveling Salesman Problem with Graph Neural Networks. 2023. URL: https://medium.com/stanford-cs224w/tackling-the-traveling-salesman-problem-with-graph-neural-networks-b86ef4300c6e

That architecture was in turn inspired by the following papers:

Joshi C. K., Laurent T., Bresson X. An Efficient Graph Convolutional Network Technique for the Travelling Salesman Problem. 2019. arXiv: 1906.01227 [cs.LG]. URL: https://arxiv.org/abs/1906.01227

Bresson X., Laurent T. The Transformer Network for the Traveling Salesman Problem. 2021. arXiv: 2103.03012 [cs.LG]. URL: https://arxiv.org/abs/2103.03012


## TSP Dataset

In [None]:
#@title Dataloader definitions
class DotDict(dict):
    """Wrapper around in-built dict class to access members through the dot operation.
    """

    def __init__(self, **kwds):
        self.update(kwds)
        self.__dict__ = self


class GoogleTSPReader(object):
    """Iterator that reads TSP dataset files and yields mini-batches.

    Format expected as in Vinyals et al., 2015: https://arxiv.org/abs/1506.03134, http://goo.gl/NDcOIG
    """

    def __init__(self, num_nodes, num_neighbors, batch_size, filepath):
        """
        Args:
            num_nodes: Number of nodes in TSP tours
            num_neighbors: Number of neighbors to consider for each node in graph
            batch_size: Batch size
            filepath: Path to dataset file (.txt file)
        """
        self.num_nodes = num_nodes
        self.num_neighbors = num_neighbors
        self.batch_size = batch_size
        self.filepath = filepath
        # Read all lines and close the file handle before shuffling
        with open(filepath, "r") as f:
            self.filedata = shuffle(f.readlines())  # Always shuffle upon reading data
        self.max_iter = (len(self.filedata) // batch_size)

    def __iter__(self):
        for batch in range(self.max_iter):
            start_idx = batch * self.batch_size
            end_idx = (batch + 1) * self.batch_size
            yield self.process_batch(self.filedata[start_idx:end_idx])

    def process_batch(self, lines):
        """Helper function to convert raw lines into a mini-batch as a DotDict.
        """
        batch_edges = []
        batch_edges_values = []
        batch_edges_target = []  # Binary classification targets (0/1)
        batch_nodes = []
        batch_nodes_target = []  # Multi-class classification targets (`num_nodes` classes)
        batch_nodes_coord = []
        batch_tour_nodes = []
        batch_tour_len = []

        for line_num, line in enumerate(lines):
            line = line.split(" ")  # Split into list

            # Compute signal on nodes
            nodes = np.ones(self.num_nodes)  # All 1s for TSP...

            # Convert node coordinates to required format
            nodes_coord = []
            for idx in range(0, 2 * self.num_nodes, 2):
                nodes_coord.append([float(line[idx]), float(line[idx + 1])])

            # Compute distance matrix
            W_val = squareform(pdist(nodes_coord, metric='euclidean'))

            # Compute adjacency matrix
            if self.num_neighbors == -1:
                W = np.ones((self.num_nodes, self.num_nodes))  # Graph is fully connected
            else:
                W = np.zeros((self.num_nodes, self.num_nodes))
                # Determine k-nearest neighbors for each node
                knns = np.argpartition(W_val, kth=self.num_neighbors, axis=-1)[:, self.num_neighbors::-1]
                # Make connections
                for idx in range(self.num_nodes):
                    W[idx][knns[idx]] = 1
            np.fill_diagonal(W, 2)  # Special token for self-connections

            # Convert tour nodes to required format
            # Don't add final connection for tour/cycle
            tour_nodes = [int(node) - 1 for node in line[line.index('output') + 1:-1]][:-1]

            # Compute node and edge representation of tour + tour_len
            tour_len = 0
            nodes_target = np.zeros(self.num_nodes)
            edges_target = np.zeros((self.num_nodes, self.num_nodes))
            for idx in range(len(tour_nodes) - 1):
                i = tour_nodes[idx]
                j = tour_nodes[idx + 1]
                nodes_target[i] = idx  # node targets: ordering of nodes in tour
                edges_target[i][j] = 1
                edges_target[j][i] = 1
                tour_len += W_val[i][j]

            # Add final connection of tour in edge target
            nodes_target[j] = len(tour_nodes) - 1
            edges_target[j][tour_nodes[0]] = 1
            edges_target[tour_nodes[0]][j] = 1
            tour_len += W_val[j][tour_nodes[0]]

            # Concatenate the data
            batch_edges.append(W)
            batch_edges_values.append(W_val)
            batch_edges_target.append(edges_target)
            batch_nodes.append(nodes)
            batch_nodes_target.append(nodes_target)
            batch_nodes_coord.append(nodes_coord)
            batch_tour_nodes.append(tour_nodes)
            batch_tour_len.append(tour_len)

        # From list to tensors as a DotDict
        batch = DotDict()
        batch.edges = np.stack(batch_edges, axis=0)
        batch.edges_values = np.stack(batch_edges_values, axis=0)
        batch.edges_target = np.stack(batch_edges_target, axis=0)
        batch.nodes = np.stack(batch_nodes, axis=0)
        batch.nodes_target = np.stack(batch_nodes_target, axis=0)
        batch.nodes_coord = np.stack(batch_nodes_coord, axis=0)
        batch.tour_nodes = np.stack(batch_tour_nodes, axis=0)
        batch.tour_len = np.stack(batch_tour_len, axis=0)
        return batch

dtypeFloat = torch.cuda.FloatTensor if device.type == 'cuda' else torch.FloatTensor
dtypeLong = torch.cuda.LongTensor if device.type == 'cuda' else torch.LongTensor
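
`process_batch` expects each line to hold `2 * num_nodes` coordinates, the token `output`, and a 1-based tour that repeats the first node at the end. The sketch below parses one such line with the standard library only; the 4-node sample line and the helper name `parse_tsp_line` are made up for illustration and are not part of the reader.

```python
import math

def parse_tsp_line(line, num_nodes):
    """Parse 'x1 y1 ... xn yn output t1 ... tn t1' into coords, tour, length."""
    # split() discards the trailing newline, so only the repeated start
    # node has to be dropped from the tour afterwards.
    tokens = line.split()
    coords = [(float(tokens[i]), float(tokens[i + 1]))
              for i in range(0, 2 * num_nodes, 2)]
    start = tokens.index("output") + 1
    # The tour is 1-based and closed: convert to 0-based, drop the repeat.
    tour = [int(t) - 1 for t in tokens[start:]][:-1]
    # Sum edge lengths around the cycle, including the closing edge.
    length = sum(math.dist(coords[tour[k]], coords[tour[(k + 1) % len(tour)]])
                 for k in range(len(tour)))
    return coords, tour, length

# Hypothetical instance: the four corners of the unit square.
sample = "0 0 1 0 1 1 0 1 output 1 2 3 4 1"
coords, tour, length = parse_tsp_line(sample, num_nodes=4)
print(tour, length)
```

For the unit square the tour visits the corners in order, so the computed length is the perimeter.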


In [None]:
num_nodes = 20
num_neighbors = -1    # when set to -1, it considers all the connections instead of k nearest neighbors
train_filepath = f"./tsp_data/tsp{num_nodes}_train_concorde.txt"

batch_size = 1
dataset = GoogleTSPReader(num_nodes, num_neighbors, batch_size, train_filepath)

t = time.time()
batch = next(iter(dataset))  # Generate a batch of TSPs
print("Batch generation took: {:.3f} sec".format(time.time() - t))

print("edges:", batch.edges.shape)
print("edges_values:", batch.edges_values.shape)
print("edges_targets:", batch.edges_target.shape)
print("nodes:", batch.nodes.shape)
print("nodes_target:", batch.nodes_target.shape)
print("nodes_coord:", batch.nodes_coord.shape)
print("tour_nodes:", batch.tour_nodes.shape)
print("tour_len:", batch.tour_len.shape)
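
Setting `num_neighbors` to a positive value `k` makes the reader keep, for each node, only the entries selected by `np.argpartition` on the distance matrix. The sketch below repeats that step on four collinear points with the same slicing as `process_batch`; the coordinates and the value of `k` are made up. Because every node is at distance 0 from itself, the selected `k + 1` indices include the node itself, which the diagonal fill then overwrites with the self-connection token 2.

```python
import numpy as np

coords = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [7.0, 0.0]])
k = 1  # hypothetical value for num_neighbors

# Pairwise Euclidean distances via broadcasting.
diff = coords[:, None, :] - coords[None, :, :]
W_val = np.sqrt((diff ** 2).sum(axis=-1))

# Indices of the k + 1 smallest distances per row (self included).
knns = np.argpartition(W_val, kth=k, axis=-1)[:, k::-1]

W = np.zeros_like(W_val)
for i in range(len(coords)):
    W[i, knns[i]] = 1
np.fill_diagonal(W, 2)  # special token for self-connections
print(W)
```

The resulting matrix need not be symmetric: node 2 lists node 1 as its nearest neighbour, but node 1 is closer to node 0.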


In [None]:
#@title Plotting helper function

def plot_tsp(p, x_coord, W, W_val, W_target, title="default"):
    """
    Helper function to plot TSP tours.

    Args:
        p: Matplotlib figure/subplot
        x_coord: Coordinates of nodes
        W: Edge adjacency matrix
        W_val: Edge values (distance) matrix
        W_target: One-hot matrix with 1s on groundtruth/predicted edges
        title: Title of figure/subplot

    Returns:
        p: Updated figure/subplot

    """

    def _edges_to_node_pairs(W):
        """Helper function to convert edge matrix into pairs of adjacent nodes.
        """
        pairs = []
        for r in range(len(W)):
            for c in range(len(W)):
                if W[r][c] == 1:
                    pairs.append((r, c))
        return pairs

    G = nx.DiGraph(W_val)
    pos = dict(zip(range(len(x_coord)), x_coord.tolist()))
    adj_pairs = _edges_to_node_pairs(W)
    target_pairs = _edges_to_node_pairs(W_target)
    colors = ['g'] + ['b'] * (len(x_coord) - 1)  # Green for 0th node, blue for others
    nx.draw_networkx_nodes(G, pos, node_color=colors, node_size=50)
    nx.draw_networkx_edges(G, pos, edgelist=adj_pairs, alpha=0.3, width=0.5)
    nx.draw_networkx_edges(G, pos, edgelist=target_pairs, alpha=1, width=1, edge_color='r')
    p.set_title(title)
    return p

In [None]:
idx = 0
f = plt.figure(figsize=(5, 5))
a = f.add_subplot(111)
plot_tsp(a, batch.nodes_coord[idx], batch.edges[idx], batch.edges_values[idx], batch.edges_target[idx])

In [None]:
#@title Batch normalization layers
class BatchNormNode(nn.Module):
    """Batch normalization for node features.
    """

    def __init__(self, hidden_dim):
        super(BatchNormNode, self).__init__()
        self.batch_norm = nn.BatchNorm1d(hidden_dim, track_running_stats=False)

    def forward(self, x):
        """
        Args:
            x: Node features (batch_size, num_nodes, hidden_dim)

        Returns:
            x_bn: Node features after batch normalization (batch_size, num_nodes, hidden_dim)
        """
        x_trans = x.transpose(1, 2).contiguous()  # Reshape input: (batch_size, hidden_dim, num_nodes)
        x_trans_bn = self.batch_norm(x_trans)
        x_bn = x_trans_bn.transpose(1, 2).contiguous()  # Reshape to original shape
        return x_bn


class BatchNormEdge(nn.Module):
    """Batch normalization for edge features.
    """

    def __init__(self, hidden_dim):
        super(BatchNormEdge, self).__init__()
        self.batch_norm = nn.BatchNorm2d(hidden_dim, track_running_stats=False)

    def forward(self, e):
        """
        Args:
            e: Edge features (batch_size, num_nodes, num_nodes, hidden_dim)

        Returns:
            e_bn: Edge features after batch normalization (batch_size, num_nodes, num_nodes, hidden_dim)
        """
        e_trans = e.transpose(1, 3).contiguous()  # Reshape input: (batch_size, hidden_dim, num_nodes, num_nodes)
        e_trans_bn = self.batch_norm(e_trans)
        e_bn = e_trans_bn.transpose(1, 3).contiguous()  # Reshape to original
        return e_bn

In [None]:
#@title MLP layer
class MLP(nn.Module):
    """Multi-layer Perceptron for output prediction.
    """

    def __init__(self, hidden_dim, output_dim, L=2):
        super(MLP, self).__init__()
        self.L = L
        U = []
        for layer in range(self.L - 1):
            U.append(nn.Linear(hidden_dim, hidden_dim, True))
        self.U = nn.ModuleList(U)
        self.V = nn.Linear(hidden_dim, output_dim, True)
        self.dropout_mlp = nn.Dropout(0.1)

    def forward(self, x):
        """
        Args:
            x: Input features (batch_size, hidden_dim)

        Returns:
            y: Output predictions (batch_size, output_dim)
        """
        Ux = x
        for U_i in self.U:
            Ux = U_i(Ux)  # B x H
            Ux = F.relu(Ux)  # B x H
            Ux = self.dropout_mlp(Ux)
        y = self.V(Ux)  # B x O
        return y

## Node and Edge Embeddings

In [None]:
class NodeFeatures(nn.Module):
    """Convnet features for nodes.

    Using `sum` aggregation:
        x_i = U*x_i +  sum_j [ gate_ij * (V*x_j) ]

    Using `mean` aggregation:
        x_i = U*x_i + ( sum_j [ gate_ij * (V*x_j) ] / sum_j [ gate_ij] )
    """

    def __init__(self, hidden_dim, aggregation="mean"):
        super(NodeFeatures, self).__init__()
        self.aggregation = aggregation
        self.U = nn.Linear(hidden_dim, hidden_dim, True)
        self.V = nn.Linear(hidden_dim, hidden_dim, True)

    def forward(self, x, edge_gate):
        """
        Args:
            x: Node features (batch_size, num_nodes, hidden_dim)
            edge_gate: Edge gate values (batch_size, num_nodes, num_nodes, hidden_dim)

        Returns:
            x_new: Convolved node features (batch_size, num_nodes, hidden_dim)
        """
        Ux = self.U(x)  # B x V x H
        Vx = self.V(x)  # B x V x H
        Vx = Vx.unsqueeze(1)  # extend Vx from "B x V x H" to "B x 1 x V x H"
        gateVx = edge_gate * Vx  # B x V x V x H
        if self.aggregation=="mean":
            x_new = Ux + torch.sum(gateVx, dim=2) / (1e-20 + torch.sum(edge_gate, dim=2))  # B x V x H
        elif self.aggregation=="sum":
            x_new = Ux + torch.sum(gateVx, dim=2)  # B x V x H
        return x_new


class EdgeFeatures(nn.Module):
    """Convnet features for edges.

    e_ij = U*e_ij + V*(x_i + x_j)
    """

    def __init__(self, hidden_dim):
        super(EdgeFeatures, self).__init__()
        self.U = nn.Linear(hidden_dim, hidden_dim, True)
        self.V = nn.Linear(hidden_dim, hidden_dim, True)

    def forward(self, x, e):
        """
        Args:
            x: Node features (batch_size, num_nodes, hidden_dim)
            e: Edge features (batch_size, num_nodes, num_nodes, hidden_dim)

        Returns:
            e_new: Convolved edge features (batch_size, num_nodes, num_nodes, hidden_dim)
        """
        Ue = self.U(e)
        Vx = self.V(x)
        Wx = Vx.unsqueeze(1)  # Extend Vx from "B x V x H" to "B x 1 x V x H"
        Vx = Vx.unsqueeze(2)  # Extend Vx from "B x V x H" to "B x V x 1 x H"
        e_new = Ue + Vx + Wx
        return e_new
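
The two aggregation formulas from the `NodeFeatures` docstring differ only in the normalisation by the summed gates. The NumPy sketch below evaluates both for a single graph with random gates. The linear maps `U` and `V` are replaced by the identity purely for brevity, so this illustrates the aggregation step and is not the layer itself.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes, hidden_dim = 4, 3
x = rng.normal(size=(num_nodes, hidden_dim))                 # node features, V x H
gate = rng.uniform(size=(num_nodes, num_nodes, hidden_dim))  # gate values in [0, 1), V x V x H

Ux = x  # assumption: U is the identity
Vx = x  # assumption: V is the identity

# gated[i, j] = gate_ij * (V x_j)
gated = gate * Vx[None, :, :]

# `sum` aggregation: add up the gated neighbour messages over j.
x_sum = Ux + gated.sum(axis=1)

# `mean` aggregation: normalise by the summed gates (epsilon avoids 0 / 0).
x_mean = Ux + gated.sum(axis=1) / (1e-20 + gate.sum(axis=1))

print(x_sum.shape, x_mean.shape)
```

With all gates equal, the `mean` variant reduces to adding the plain average of the neighbour features, which keeps the scale of the update independent of the number of nodes.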

## Residual Gated GCN Layer

In [None]:
class ResidualGatedGCNLayer(nn.Module):
    """Convnet layer with gating and residual connection.
    """

    def __init__(self, hidden_dim, aggregation="sum"):
        super(ResidualGatedGCNLayer, self).__init__()
        self.node_feat = NodeFeatures(hidden_dim, aggregation)
        self.edge_feat = EdgeFeatures(hidden_dim)
        self.bn_node = BatchNormNode(hidden_dim)
        self.bn_edge = BatchNormEdge(hidden_dim)
        # self.dropout_layer = nn.Dropout(0.1)

    def forward(self, x, e):
        """
        Args:
            x: Node features (batch_size, num_nodes, hidden_dim)
            e: Edge features (batch_size, num_nodes, num_nodes, hidden_dim)

        Returns:
            x_new: Convolved node features (batch_size, num_nodes, hidden_dim)
            e_new: Convolved edge features (batch_size, num_nodes, num_nodes, hidden_dim)
        """
        e_in = e
        x_in = x
        # Edge convolution
        e_tmp = self.edge_feat(x_in, e_in)  # B x V x V x H
        # Compute edge gates
        edge_gate = torch.sigmoid(e_tmp)
        # Node convolution
        x_tmp = self.node_feat(x_in, edge_gate)
        # Batch normalization
        e_tmp = self.bn_edge(e_tmp)
        x_tmp = self.bn_node(x_tmp)
        # ReLU Activation
        e = F.relu(e_tmp)
        x = F.relu(x_tmp)
        # Dropout (Optional)
        # x = self.dropout_layer(x)
        # e = self.dropout_layer(e)
        # Residual connection
        x_new = x_in + x
        e_new = e_in + e
        return x_new, e_new

## Residual Gated GCN Model

In [None]:
from torch_geometric.nn import TransformerConv

class ResidualGatedGCNModel(nn.Module):
    """Residual Gated GCN Model for outputting predictions as edge adjacency matrices.
    """

    def __init__(self, config, dtypeFloat, dtypeLong):
        super(ResidualGatedGCNModel, self).__init__()
        self.dtypeFloat = dtypeFloat
        self.dtypeLong = dtypeLong
        # Define net parameters
        self.num_nodes = config['num_nodes']
        self.node_dim = config['node_dim']
        self.voc_nodes_in = config['voc_nodes_in']
        self.voc_nodes_out = config['num_nodes']
        self.voc_edges_in = config['voc_edges_in']
        self.voc_edges_out = config['voc_edges_out']
        self.hidden_dim = config['hidden_dim']
        self.num_layers = config['num_layers']
        self.mlp_layers = config['mlp_layers']
        self.aggregation = config['aggregation']
        # Node and edge embedding layers/lookups

        # We are using TransformerConv layer from torch geometric library!
        self.nodes_coord_embedding = TransformerConv(self.node_dim, self.hidden_dim)

        self.edges_values_embedding = nn.Linear(1, self.hidden_dim//2, bias=False)
        self.edges_embedding = nn.Embedding(self.voc_edges_in, self.hidden_dim//2)
        # Define GCN Layers
        gcn_layers = []
        for layer in range(self.num_layers):
            gcn_layers.append(ResidualGatedGCNLayer(self.hidden_dim, self.aggregation))
        self.gcn_layers = nn.ModuleList(gcn_layers)
        # Define MLP classifiers
        self.mlp_edges = MLP(self.hidden_dim, self.voc_edges_out, self.mlp_layers)

    def loss_edges(self, y_pred_edges, y_edges, edge_cw):
        """
        Loss function for edge predictions.

        Args:
            y_pred_edges: Predictions for edges (batch_size, num_nodes, num_nodes, voc_edges_out)
            y_edges: Targets for edges (batch_size, num_nodes, num_nodes)
            edge_cw: Class weights for edges loss

        Returns:
            loss_edges: Value of loss function

        """
        # Edge loss
        y = F.log_softmax(y_pred_edges, dim=3)  # B x V x V x voc_edges
        y = y.permute(0, 3, 1, 2)  # B x voc_edges x V x V
        loss_edges = nn.NLLLoss(edge_cw)
        loss = loss_edges(y.contiguous(), y_edges)
        return loss

    def forward(self, x_edges, x_edges_values, x_nodes, x_nodes_coord, y_edges, edge_cw):
        """
        Args:
            x_edges: Input edge adjacency matrix (batch_size, num_nodes, num_nodes)
            x_edges_values: Input edge distance matrix (batch_size, num_nodes, num_nodes)
            x_nodes: Input nodes (batch_size, num_nodes)
            x_nodes_coord: Input node coordinates (batch_size, num_nodes, node_dim)
            y_edges: Targets for edges (batch_size, num_nodes, num_nodes)
            edge_cw: Class weights for edges loss
            # y_nodes: Targets for nodes (batch_size, num_nodes, num_nodes)
            # node_cw: Class weights for nodes loss

        Returns:
            y_pred_edges: Predictions for edges (batch_size, num_nodes, num_nodes, voc_edges_out)
            # y_pred_nodes: Predictions for nodes (batch_size, num_nodes)
            loss: Value of loss function
        """
        # Node and edge embedding
        # Note: the squeeze/unsqueeze calls below assume batch_size == 1
        edge_index = torch.squeeze(x_edges).nonzero().t().contiguous()
        x = self.nodes_coord_embedding(torch.squeeze(x_nodes_coord), edge_index)
        x = torch.unsqueeze(x, 0)
        e_vals = self.edges_values_embedding(x_edges_values.unsqueeze(3))  # B x V x V x H/2
        e_tags = self.edges_embedding(x_edges)  # B x V x V x H/2
        e = torch.cat((e_vals, e_tags), dim=3)  # B x V x V x H
        # GCN layers
        for layer in range(self.num_layers):
            x, e = self.gcn_layers[layer](x, e)  # B x V x H, B x V x V x H
        # MLP classifier
        y_pred_edges = self.mlp_edges(e)  # B x V x V x voc_edges_out

        # Compute loss on the same device as the predictions (no hard-coded .cuda())
        edge_cw = torch.as_tensor(edge_cw, dtype=torch.float32, device=y_pred_edges.device)
        loss = self.loss_edges(y_pred_edges, y_edges.to(y_pred_edges.device), edge_cw)

        return y_pred_edges, loss
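
In a TSP tour over `n` nodes only `2n` of the `n * n` entries of the symmetric edge target matrix are 1, so the two edge classes are heavily imbalanced; this is why `loss_edges` takes class weights. The sketch below works through the "balanced" weighting rule `n_samples / (n_classes * count_c)`, which is the formula scikit-learn documents for `compute_class_weight("balanced", ...)`, and the weighted negative log-likelihood that `nn.NLLLoss` computes under its default mean reduction. Both helper names are hypothetical, and plain Python lists stand in for tensors.

```python
import math

def balanced_class_weights(labels):
    """weight_c = n_samples / (n_classes * count_c)."""
    classes = sorted(set(labels))
    return {c: len(labels) / (len(classes) * labels.count(c)) for c in classes}

def weighted_nll(logits, labels, weights):
    """Weighted mean of -log softmax(logits)[label], normalised by the weight sum."""
    total, norm = 0.0, 0.0
    for row, y in zip(logits, labels):
        log_z = math.log(sum(math.exp(v) for v in row))
        total += weights[y] * (log_z - row[y])
        norm += weights[y]
    return total / norm

n = 20
labels = [1] * (2 * n) + [0] * (n * n - 2 * n)  # flattened edge targets
weights = balanced_class_weights(labels)
print(weights)

# Uninformative logits give a loss of log(2) regardless of the weights.
loss = weighted_nll([[0.0, 0.0]] * len(labels), labels, weights)
print(loss)
```

For `n = 20` the tour-edge class is weighted nine times more heavily than the non-edge class, which offsets the 40-to-360 imbalance in the targets.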

## Utility Functions (Saving and Loading Models)

In [None]:

# Save the model state
def save_model(net, filepath):
    """Save model state.

    Args:
        net: PyTorch model
        filepath: Path to save model
    """
    torch.save(net.state_dict(), filepath)
    print(f"Model saved to {filepath}")

# Load a saved model state
def load_model(model_class, config, filepath, dtypeFloat, dtypeLong):  # Takes the model class to instantiate
    """
    Load model state into an instance of the specified model_class.
    """
    if not os.path.exists(filepath):
        print(f"Model file not found: {filepath}")
        return None
    try:
        # Instantiate the requested model class
        model_instance = model_class(config, dtypeFloat, dtypeLong)
        # Wrap in DataParallel after construction
        net = nn.DataParallel(model_instance)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if torch.cuda.is_available():
            net.cuda()  # Move the wrapper and the model to the GPU first
            # Load the state_dict with an explicit map_location in case
            # the model was saved on a different device
            state_dict = torch.load(filepath, map_location=device)
            net.load_state_dict(state_dict)
        else:
            # Load onto the CPU
            state_dict = torch.load(filepath, map_location=torch.device('cpu'))
            net.load_state_dict(state_dict)

        net.eval()  # Switch to evaluation mode
        print(f"Model of type {model_class.__name__} loaded successfully from {filepath}")
        return net
    except (AttributeError, RuntimeError) as err:
        # load_state_dict raises RuntimeError on missing or unexpected keys
        print(f"Error loading model state from {filepath}: likely an architecture mismatch - {err}")
        return None
    except Exception as e:
        print(f"Error loading model from {filepath}: {e}")
        import traceback
        traceback.print_exc()  # Print the full traceback for diagnostics
        return None
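
Because the model is saved while wrapped in `nn.DataParallel`, every key in the checkpoint carries a `module.` prefix, which is why `load_model` wraps the fresh instance before calling `load_state_dict`. If the weights ever need to go into an unwrapped model, the prefix can be stripped first. The sketch below does this on a plain dictionary standing in for a state dict; the helper name `strip_module_prefix` and the placeholder values are illustrative only.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DataParallel prefix removed."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Stand-in for a checkpoint saved from a DataParallel-wrapped model;
# the values are placeholders, not real weights.
checkpoint = {"module.mlp_edges.V.weight": [0.1], "module.mlp_edges.V.bias": [0.0]}
print(strip_module_prefix(checkpoint))
```

Keys without the prefix pass through unchanged, so the helper is safe to apply to checkpoints saved either way.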

## Hyperparameter Setup

In [None]:
#@title Hyperparameters

num_nodes = 20 #@param # Could also be 10, 20, or 30!
num_neighbors = -1 # Could increase it!
train_filepath = f"./tsp_data/tsp{num_nodes}_train_concorde.txt"
test_filepath = f"./tsp_data/tsp{num_nodes}_test_concorde.txt"
val_filepath = f"./tsp_data/tsp{num_nodes}_val_concorde.txt"
hidden_dim = 300 #@param
num_layers = 5 #@param
mlp_layers = 2 #@param
learning_rate = 0.001 #@param
max_epochs = 30 #@param
batches_per_epoch = 2000

variables = {'train_filepath': train_filepath,
             'val_filepath': val_filepath,
             'test_filepath': test_filepath,
             'num_nodes': num_nodes,
             'num_neighbors': num_neighbors,
             'node_dim': 2 ,
             'voc_nodes_in': 2,
             'voc_nodes_out': 2,
             'voc_edges_in': 3,
             'voc_edges_out': 2,
             'hidden_dim': hidden_dim,
             'num_layers': num_layers,
             'mlp_layers': mlp_layers,
             'aggregation': 'mean',
             'max_epochs': max_epochs,
             'val_every': 3,
             'test_every': 3,
             'batches_per_epoch': batches_per_epoch,
             'accumulation_steps': 1,
             'learning_rate': learning_rate,
             'decay_rate': 1.01,
             'batch_size': 1
             }
net = nn.DataParallel(ResidualGatedGCNModel(variables,  torch.cuda.FloatTensor, torch.cuda.LongTensor))
net.cuda()

# Compute number of network parameters
nb_param = 0
for param in net.parameters():
    nb_param += np.prod(list(param.data.size()))
print('Number of parameters:', nb_param)

## Training Loop

In [None]:
def train_one_epoch(net, optimizer, config):
    # Set training mode
    net.train()

    # Assign parameters
    num_nodes = config['num_nodes']
    num_neighbors = config['num_neighbors']
    batches_per_epoch = config['batches_per_epoch']
    accumulation_steps = config['accumulation_steps']
    train_filepath = config['train_filepath']
    batch_size = config['batch_size']

    # Load TSP data
    dataset = GoogleTSPReader(num_nodes, num_neighbors, batch_size, train_filepath)
    if batches_per_epoch != -1:
        batches_per_epoch = min(batches_per_epoch, dataset.max_iter)
    else:
        batches_per_epoch = dataset.max_iter

    # Convert dataset to iterable
    dataset = iter(dataset)

    # Initially set loss class weights as None
    edge_cw = None

    # Initialize running data
    running_loss = 0.0
    running_nb_data = 0

    start_epoch = time.time()
    for batch_num in range(batches_per_epoch):
        # Generate a batch of TSPs
        try:
            batch = next(dataset)
        except StopIteration:
            break

        # Convert batch to torch Variables
        x_edges = Variable(torch.tensor(batch.edges, dtype=torch.int64, device='cuda'))
        x_edges_values = Variable(torch.tensor(batch.edges_values, dtype=torch.float32, device='cuda'))
        x_nodes = Variable(torch.tensor(batch.nodes, dtype=torch.int64, device='cuda'))
        x_nodes_coord = Variable(torch.tensor(batch.nodes_coord, dtype=torch.float32, device='cuda'))
        y_edges = Variable(torch.tensor(batch.edges_target, dtype=torch.int64, device='cuda'))
        y_nodes = Variable(torch.tensor(batch.nodes_target, dtype=torch.int64, device='cuda'))

        # Compute class weights (if uncomputed)
        if type(edge_cw) != torch.Tensor:
            edge_labels = y_edges.cpu().numpy().flatten()
            edge_cw = compute_class_weight("balanced", classes=np.unique(edge_labels), y=edge_labels)

        # Forward pass
        y_preds, loss = net.forward(x_edges, x_edges_values, x_nodes, x_nodes_coord, y_edges, edge_cw)
        loss = loss.mean()  # Take mean of loss across multiple GPUs
        loss = loss / accumulation_steps  # Scale loss by accumulation steps

        # Backward pass (gradients accumulate across batches)
        loss.backward()

        # Optimizer step once every `accumulation_steps` batches
        if (batch_num+1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        # Update running data
        running_nb_data += 1
        running_loss += loss.data.item()* accumulation_steps  # Re-scale loss

    # Compute statistics for full epoch
    loss = running_loss/ running_nb_data

    return time.time()-start_epoch, loss
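
The loop divides each loss by `accumulation_steps` and calls the optimizer only every `accumulation_steps` batches, so the accumulated gradient equals the gradient of the mean loss over those batches. The framework-free sketch below checks that identity for a one-parameter least-squares model; the data and the model are invented for the illustration.

```python
def grad_squared_error(w, x, y):
    """d/dw of (w * x - y) ** 2."""
    return 2.0 * (w * x - y) * x

w = 0.5
batches = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 9.0)]  # made-up (x, y) pairs
accumulation_steps = len(batches)

# Accumulate gradients of the scaled per-batch losses, as repeated
# loss.backward() calls would between two optimizer steps.
accumulated = 0.0
for x, y in batches:
    accumulated += grad_squared_error(w, x, y) / accumulation_steps

# Gradient of the mean loss computed in one go.
mean_grad = sum(grad_squared_error(w, x, y) for x, y in batches) / len(batches)
print(accumulated, mean_grad)
```

With `accumulation_steps` equal to 1, as in the `variables` dictionary, the scaling is a no-op and the optimizer steps after every batch.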

## Testing (Validation) Loop

In [None]:
def test(net, config, mode='test'):
    # Set evaluation mode
    net.eval()

    # Assign parameters
    num_nodes = config['num_nodes']
    num_neighbors = config['num_neighbors']
    batches_per_epoch = 1 # config['batches_per_epoch']
    val_filepath = config['val_filepath']
    test_filepath = config['test_filepath']
    batch_size = config['batch_size']

    # Load TSP data
    if mode == 'val':
        dataset = GoogleTSPReader(num_nodes, num_neighbors, batch_size, filepath=val_filepath)
    elif mode == 'test':
        dataset = GoogleTSPReader(num_nodes, num_neighbors, batch_size, filepath=test_filepath)

    # Convert dataset to iterable
    dataset = iter(dataset)

    # Initially set loss class weights as None
    edge_cw = None

    # Initialize running data
    running_loss = 0.0
    running_nb_data = 0

    with torch.no_grad():
        start_test = time.time()
        for batch_num in range(batches_per_epoch):
            # Generate a batch of TSPs
            try:
                batch = next(dataset)
            except StopIteration:
                break

            # Convert batch to torch Variables
            x_edges = Variable(torch.tensor(batch.edges, dtype=torch.int64, device='cuda'))
            x_edges_values = Variable(torch.tensor(batch.edges_values, dtype=torch.float32, device='cuda'))
            x_nodes = Variable(torch.tensor(batch.nodes, dtype=torch.int64, device='cuda'))
            x_nodes_coord = Variable(torch.tensor(batch.nodes_coord, dtype=torch.float32, device='cuda'))
            y_edges = Variable(torch.tensor(batch.edges_target, dtype=torch.int64, device='cuda'))
            y_nodes = Variable(torch.tensor(batch.nodes_target, dtype=torch.int64, device='cuda'))

            # Compute class weights (if uncomputed)
            if type(edge_cw) != torch.Tensor:
                edge_labels = y_edges.cpu().numpy().flatten()
                edge_cw = compute_class_weight("balanced", classes=np.unique(edge_labels), y=edge_labels)

            # Forward pass
            y_preds, loss = net.forward(x_edges, x_edges_values, x_nodes, x_nodes_coord, y_edges, edge_cw)
            loss = loss.mean()  # Take mean of loss across multiple GPUs

            # Update running data
            running_nb_data += 1
            running_loss += loss.data.item()

    # Compute statistics for full epoch
    loss = running_loss/ running_nb_data

    return time.time()-start_test, loss

def update_learning_rate(optimizer, lr):
    """
    Updates learning rate for given optimizer.

    Args:
        optimizer: Optimizer object
        lr: New learning rate

    Returns:
        optimizer: Updated optimizer object
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer

## Model Training

In [None]:
# Define optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=variables["learning_rate"])
# optimizer = torch.optim.Adam(net.parameters(), lr=variables['learning_rate'], weight_decay=1e-5)
val_loss_old = None

train_losses = []
val_losses = []
test_losses = []

for epoch in range(variables["max_epochs"]):  # Keep the loop consistent with the final-epoch checks
    # Train
    train_time, train_loss = train_one_epoch(net, optimizer, variables)
    # Print metrics
    train_losses.append(train_loss)
    print(f"Epoch: {epoch}, Train Loss: {train_loss}")

    if epoch % variables["val_every"] == 0 or epoch == variables["max_epochs"]-1:
        # Validate
        val_time, val_loss = test(net, variables, mode='val')
        val_losses.append(val_loss)
        print(f"Epoch: {epoch}, Val Loss: {val_loss}")

        # Update learning rate
        if val_loss_old is not None and val_loss > 0.99 * val_loss_old:
            variables["learning_rate"] /= variables["decay_rate"]
            optimizer = update_learning_rate(optimizer, variables["learning_rate"])

        val_loss_old = val_loss  # Update old validation loss

    if epoch % variables["test_every"] == 0 or epoch == variables["max_epochs"]-1:
        # Test
        test_time, test_loss = test(net, variables, mode='test')
        test_losses.append(test_loss)
        print(f"Epoch: {epoch}, Test Loss: {test_loss}\n")
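
The loop above divides the learning rate by `decay_rate` whenever the validation loss fails to drop at least 1% below the previous check. A minimal pure-Python sketch of that rule follows; the helper name `decayed_learning_rates` and the sample losses are illustrative assumptions, not part of the notebook.

```python
def decayed_learning_rates(val_losses, initial_lr, decay_rate):
    """Replay the decay rule on a list of validation losses.

    Hypothetical helper for illustration: returns the learning rate
    in effect after each validation check.
    """
    lr = initial_lr
    previous = None
    history = []
    for loss in val_losses:
        # Decay when the new loss is not at least 1% below the previous one
        if previous is not None and loss > 0.99 * previous:
            lr /= decay_rate
        previous = loss
        history.append(lr)
    return history


# Made-up losses: a clear improvement, then a plateau, then a regression
rates = decayed_learning_rates([1.00, 0.90, 0.895, 0.91], initial_lr=0.001, decay_rate=2.0)
print(rates)
```

With these numbers the rate stays put after the first improvement and is halved at each of the last two checks, since neither beats the previous loss by 1%.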

In [None]:
filepath = "./tsp_gnn_model_naive.pt"
torch.save(net.state_dict(), filepath)
print(f"Model saved to {filepath}")

In [None]:
#@title Plotting helper functions

def plot_loss_curve(train_loss, val_loss, test_loss, config):
    """
    Plot training, validation, and test loss curves.

    Parameters:
    - train_loss: List of training losses for each epoch
    - val_loss: List of validation losses
    - test_loss: List of test losses
    - config: Dictionary containing plotting configuration
    """
    # Create a figure with a specific size
    plt.figure(figsize=(15, 10))

    # Plot training loss (typically at every epoch)
    plt.plot(train_loss, color='green', label='Train Loss')

    # Plot validation loss (at specified intervals)
    val_every = config.get("val_every", 1)
    val_x = [i * val_every for i in range(len(val_loss))]
    plt.plot(val_x, val_loss, color='orange', label='Val Loss')

    # Plot test loss (at specified intervals)
    test_every = config.get("test_every", 1)
    test_x = [i * test_every for i in range(len(test_loss))]
    plt.plot(test_x, test_loss, color='purple', label='Test Loss')

    # Add labels and title
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Loss Curve")
    plt.legend()

    # Show the plot
    plt.show()

def plot_tsp_heatmap(p, x_coord, W_val, W_pred, title="default"):
    """
    Helper function to plot predicted TSP tours with edge strength denoting confidence of prediction.

    Args:
        p: Matplotlib figure/subplot
        x_coord: Coordinates of nodes
        W_val: Edge values (distance) matrix
        W_pred: Edge predictions matrix
        title: Title of figure/subplot

    Returns:
        p: Updated figure/subplot

    """

    def _edges_to_node_pairs(W):
        """Helper function to convert edge matrix into pairs of adjacent nodes.
        """
        pairs = []
        edge_preds = []
        for r in range(len(W)):
            for c in range(len(W)):
                if W[r][c] > 0.25:
                    pairs.append((r, c))
                    edge_preds.append(W[r][c])
        return pairs, edge_preds

    G = nx.Graph(W_val)
    pos = dict(zip(range(len(x_coord)), x_coord.tolist()))
    node_pairs, edge_color = _edges_to_node_pairs(W_pred)
    node_color = ['g'] + ['b'] * (len(x_coord) - 1)  # Green for 0th node, blue for others
    nx.draw_networkx_nodes(G, pos, node_color=node_color, node_size=50)
    nx.draw_networkx_edges(G, pos, edgelist=node_pairs, edge_color=edge_color, edge_cmap=plt.cm.Reds, width=0.75)
    p.set_title(title)
    return p


def plot_predictions(x_nodes_coord, x_edges, x_edges_values, y_edges, y_pred_edges, num_plots=3):
    """
    Plots groundtruth TSP tours next to heatmaps of the predicted edge probabilities.

    Args:
        x_nodes_coord: Input node coordinates (batch_size, num_nodes, node_dim)
        x_edges: Input edge adjacency matrix (batch_size, num_nodes, num_nodes)
        x_edges_values: Input edge distance matrix (batch_size, num_nodes, num_nodes)
        y_edges: Groundtruth labels for edges (batch_size, num_nodes, num_nodes)
        y_pred_edges: Predictions for edges (batch_size, num_nodes, num_nodes, voc_edges)
        num_plots: Number of figures to plot

    """
    y = F.softmax(y_pred_edges, dim=3)  # B x V x V x voc_edges
    y_bins = y.argmax(dim=3)  # Binary predictions: B x V x V
    y_probs = y[:,:,:,1]  # Prediction probabilities: B x V x V
    for f_idx, idx in enumerate(np.random.choice(len(y), num_plots, replace=False)):
        f = plt.figure(f_idx, figsize=(15, 5))
        x_coord = x_nodes_coord[idx].cpu().numpy()
        W = x_edges[idx].cpu().numpy()
        W_val = x_edges_values[idx].cpu().numpy()
        W_target = y_edges[idx].cpu().numpy()
        W_sol_bins = y_bins[idx].cpu().numpy()
        W_sol_probs = y_probs[idx].cpu().numpy()
        plt1 = f.add_subplot(131)
        plot_tsp(plt1, x_coord.squeeze(), W.squeeze(), W_val.squeeze(), W_target.squeeze(), 'Groundtruth')
        plt2 = f.add_subplot(132)
        plot_tsp_heatmap(plt2, x_coord.squeeze(), W_val.squeeze(), W_sol_probs.squeeze(), 'Prediction Heatmap')
        plt.show()
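
The heatmap helper draws an edge only when its predicted probability for class 1 exceeds 0.25. The sketch below reproduces that step in plain Python, assuming two logits per edge as in `y_pred_edges`; the helper names and the logit values are made up for illustration.

```python
import math


def edge_probability(logit_no_edge, logit_edge):
    """Two-class softmax: probability that the edge belongs to the tour."""
    m = max(logit_no_edge, logit_edge)  # subtract the max for numerical stability
    exp_no = math.exp(logit_no_edge - m)
    exp_yes = math.exp(logit_edge - m)
    return exp_yes / (exp_no + exp_yes)


def edges_above_threshold(logits, threshold=0.25):
    """Return (row, col, prob) for every entry whose edge probability exceeds threshold."""
    kept = []
    for r, row in enumerate(logits):
        for c, (l0, l1) in enumerate(row):
            p = edge_probability(l0, l1)
            if p > threshold:
                kept.append((r, c, p))
    return kept


# Illustrative 2x2 grid of (logit for "no edge", logit for "edge") pairs
toy_logits = [
    [(5.0, -5.0), (0.0, 0.0)],
    [(0.0, 2.0), (3.0, 0.0)],
]
kept = edges_above_threshold(toy_logits)
print([(r, c, round(p, 3)) for r, c, p in kept])
```

Only the entries at (0, 1) and (1, 0) survive the threshold here; the other two have edge probabilities well below 0.25.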

## Testing and visualization

In [None]:
load_ = True
dtypeFloat = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
dtypeLong = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor
if load_:
    MODEL_PATH_V1 = "./tsp_gnn_model_naive.pt"

    if os.path.exists(MODEL_PATH_V1):
        try:
            # Use the class name of the first model
            net = load_model(ResidualGatedGCNModel, variables, MODEL_PATH_V1, dtypeFloat, dtypeLong)  # Make sure load_model uses ResidualGatedGCNModel
            print(f"Model 1 ({MODEL_PATH_V1}) loaded successfully.")
        except Exception as e:
            print(f"Error loading Model 1: {e}")
    else:
        print(f"Model file not found: {MODEL_PATH_V1}")

net.eval()

num_samples = 10
num_nodes = variables['num_nodes']
num_neighbors = variables['num_neighbors']
test_filepath = variables['test_filepath']
dataset = iter(GoogleTSPReader(num_nodes, num_neighbors, 1, test_filepath))


x_edges = []
x_edges_values = []
x_nodes = []
x_nodes_coord = []
y_edges = []
y_nodes = []
y_preds = []

with torch.no_grad():
    for i in range(num_samples):
        sample = next(dataset)
        # Convert batch to torch Variables
        x_edges.append(Variable(torch.LongTensor(sample.edges).type(dtypeLong), requires_grad=False))
        x_edges_values.append(Variable(torch.FloatTensor(sample.edges_values).type(dtypeFloat), requires_grad=False))
        x_nodes.append(Variable(torch.LongTensor(sample.nodes).type(dtypeLong), requires_grad=False))
        x_nodes_coord.append(Variable(torch.FloatTensor(sample.nodes_coord).type(dtypeFloat), requires_grad=False))
        y_edges.append(Variable(torch.LongTensor(sample.edges_target).type(dtypeLong), requires_grad=False))
        y_nodes.append(Variable(torch.LongTensor(sample.nodes_target).type(dtypeLong), requires_grad=False))

        # Compute class weights
        edge_labels = (y_edges[-1].cpu().numpy().flatten())
        edge_cw = compute_class_weight("balanced", classes=np.unique(edge_labels), y=edge_labels)

        # Forward pass
        y_pred, loss = net.forward(x_edges[-1], x_edges_values[-1], x_nodes[-1], x_nodes_coord[-1], y_edges[-1], edge_cw)
        y_preds.append(y_pred)


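# Each forward pass uses batch size 1, so concatenating along dim 0 gives (num_samples, num_nodes, num_nodes, voc_edges)
y_preds = torch.cat(y_preds, dim=0)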

# Plot prediction visualizations
plot_predictions(x_nodes_coord, x_edges, x_edges_values, y_edges, y_preds, num_plots=num_samples)

In [None]:
plot_loss_curve(train_losses, val_losses, test_losses, variables)

## Beam Search

In [None]:
#@title Beam search class
class Beamsearch(object):
    """Class for managing internals of beamsearch procedure.

    References:
        General: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/translate/beam.py
        For TSP: https://github.com/alexnowakvila/QAP_pt/blob/master/src/tsp/beam_search.py
    """

    def __init__(self, beam_size, batch_size, num_nodes,
                 dtypeFloat=torch.cuda.FloatTensor, dtypeLong=torch.cuda.LongTensor,
                 probs_type='raw', random_start=False):
        """
        Args:
            beam_size: Beam size
            batch_size: Batch size
            num_nodes: Number of nodes in TSP tours
            dtypeFloat: Float data type (for GPU/CPU compatibility)
            dtypeLong: Long data type (for GPU/CPU compatibility)
            probs_type: Type of probability values being handled by beamsearch (either 'raw'/'logits'/'argmax'(TODO))
            random_start: Flag for using fixed (at node 0) vs. random starting points for beamsearch
        """
        # Beamsearch parameters
        self.batch_size = batch_size
        self.beam_size = beam_size
        self.num_nodes = num_nodes
        self.probs_type = probs_type
        # Set data types
        self.dtypeFloat = dtypeFloat
        self.dtypeLong = dtypeLong
        # Set beamsearch starting nodes
        self.start_nodes = torch.zeros(batch_size, beam_size).type(self.dtypeLong)
        if random_start:
            # Random starting nodes
            self.start_nodes = torch.randint(0, num_nodes, (batch_size, beam_size)).type(self.dtypeLong)
        # Mask for constructing valid hypothesis
        self.mask = torch.ones(batch_size, beam_size, num_nodes).type(self.dtypeFloat)
        self.update_mask(self.start_nodes)  # Mask the starting node of the beam search
        # Score for each translation on the beam
        self.scores = torch.zeros(batch_size, beam_size).type(self.dtypeFloat)
        self.all_scores = []
        # Backpointers at each time-step
        self.prev_Ks = []
        # Outputs at each time-step
        self.next_nodes = [self.start_nodes]

    def get_current_state(self):
        """Get the output of the beam at the current timestep.
        """
        current_state = (self.next_nodes[-1].unsqueeze(2)
                         .expand(self.batch_size, self.beam_size, self.num_nodes))
        return current_state

    def get_current_origin(self):
        """Get the backpointers for the current timestep.
        """
        return self.prev_Ks[-1]

    def advance(self, trans_probs):
        """Advances the beam based on transition probabilities.

        Args:
            trans_probs: Probabilities of advancing from the previous step (batch_size, beam_size, num_nodes)
        """
        # Compound the previous scores (summing logits == multiplying probabilities)
        if len(self.prev_Ks) > 0:
            if self.probs_type == 'raw':
                beam_lk = trans_probs * self.scores.unsqueeze(2).expand_as(trans_probs)
            elif self.probs_type == 'logits':
                beam_lk = trans_probs + self.scores.unsqueeze(2).expand_as(trans_probs)
        else:
            beam_lk = trans_probs
            # Only use the starting nodes from the beam
            if self.probs_type == 'raw':
                beam_lk[:, 1:] = torch.zeros(beam_lk[:, 1:].size()).type(self.dtypeFloat)
            elif self.probs_type == 'logits':
                beam_lk[:, 1:] = -1e20 * torch.ones(beam_lk[:, 1:].size()).type(self.dtypeFloat)
        # Multiply by mask
        beam_lk = beam_lk * self.mask
        beam_lk = beam_lk.view(self.batch_size, -1)  # (batch_size, beam_size * num_nodes)
        # Get top k scores and indexes (k = beam_size)
        bestScores, bestScoresId = beam_lk.topk(self.beam_size, 1, True, True)
        # Update scores
        self.scores = bestScores
        # Update backpointers
        prev_k = bestScoresId // self.num_nodes  # integer division
        self.prev_Ks.append(prev_k)
        # Update outputs
        new_nodes = bestScoresId % self.num_nodes  # remainder gives the node index
        self.next_nodes.append(new_nodes)
        # Re-index mask
        perm_mask = prev_k.unsqueeze(2).expand_as(self.mask).type(self.dtypeLong)  # (batch_size, beam_size, num_nodes)
        self.mask = self.mask.gather(1, perm_mask)
        # Mask newly added nodes
        self.update_mask(new_nodes)

    def update_mask(self, new_nodes):
        """Sets new_nodes to zero in mask.
        """
        arr = (torch.arange(0, self.num_nodes).unsqueeze(0).unsqueeze(1)
               .expand_as(self.mask).type(self.dtypeLong))
        new_nodes = new_nodes.unsqueeze(2).expand_as(self.mask)
        update_mask = 1 - torch.eq(arr, new_nodes).type(self.dtypeFloat)
        self.mask = self.mask * update_mask
        if self.probs_type == 'logits':
            # Convert 0s in mask to inf
            self.mask[self.mask == 0] = 1e20

    def sort_best(self):
        """Sort the beam.
        """
        return torch.sort(self.scores, 0, True)

    def get_best(self):
        """Get the score and index of the best hypothesis in the beam.
        """
        scores, ids = self.sort_best()
        return scores[1], ids[1]

    def get_hypothesis(self, k):
        """Walk back to construct the full hypothesis.

        Args:
            k: Position in the beam to construct (usually 0s for most probable hypothesis)
        """
        assert self.num_nodes == len(self.prev_Ks) + 1
        hyp = -1 * torch.ones(self.batch_size, self.num_nodes).type(self.dtypeLong)
        for j in range(len(self.prev_Ks) - 1, -2, -1):
            hyp[:, j + 1] = self.next_nodes[j + 1].gather(1, k).view(1, self.batch_size)
            k = self.prev_Ks[j].type(self.dtypeLong).gather(1, k)
        return hyp
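
The batched tensor code above can be hard to follow, so here is a minimal single-instance sketch of the same idea in plain Python: extend every partial tour by each unvisited node, score it by the product of edge probabilities ('raw' mode), and keep the `beam_size` best. The function name `beam_search_tours` and the probability matrix are assumptions for illustration. Unlike the class, which zeroes masked entries, this sketch simply skips visited nodes.

```python
def beam_search_tours(probs, beam_size, start=0):
    """Beam search over an edge-probability matrix for a single instance."""
    num_nodes = len(probs)
    beams = [(1.0, [start])]  # (score, partial tour)
    for _ in range(num_nodes - 1):
        candidates = []
        for score, tour in beams:
            last = tour[-1]
            for nxt in range(num_nodes):
                if nxt not in tour:  # skip nodes that are already visited
                    candidates.append((score * probs[last][nxt], tour + [nxt]))
        # Keep the beam_size highest-scoring partial tours
        candidates.sort(key=lambda item: item[0], reverse=True)
        beams = candidates[:beam_size]
    return beams


# Made-up symmetric edge probabilities for 4 nodes
toy_probs = [
    [0.0, 0.9, 0.1, 0.6],
    [0.9, 0.0, 0.8, 0.2],
    [0.1, 0.8, 0.0, 0.7],
    [0.6, 0.2, 0.7, 0.0],
]
beams = beam_search_tours(toy_probs, beam_size=2)
for score, tour in beams:
    print(tour, round(score, 3))
```

The first entry of the returned list plays the role of beam position 0 in `get_hypothesis`; the remaining entries correspond to the lower-ranked candidates that `beamsearch_tour_nodes_shortest` also inspects.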

In [None]:
#@title Beam search helper functions
def W_to_tour_len(W, W_values):
    """Helper function to calculate tour length from edge adjacency matrix.
    """
    tour_len = 0
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            if W[i][j] == 1:
                tour_len += W_values[i][j]
    tour_len /= 2  # Divide by 2 because adjacency matrices are symmetric
    return tour_len


def tour_nodes_to_W(nodes):
    """Helper function to convert ordered list of tour nodes to edge adjacency matrix.
    """
    W = np.zeros((len(nodes), len(nodes)))
    for idx in range(len(nodes) - 1):
        i = int(nodes[idx])
        j = int(nodes[idx + 1])
        W[i][j] = 1
        W[j][i] = 1
    # Add final connection of tour in edge target
    W[j][int(nodes[0])] = 1
    W[int(nodes[0])][j] = 1
    return W

def tour_nodes_to_tour_len(nodes, W_values):
    """Helper function to calculate tour length from ordered list of tour nodes.
    """
    tour_len = 0
    for idx in range(len(nodes) - 1):
        i = nodes[idx]
        j = nodes[idx + 1]
        tour_len += W_values[i][j]
    # Add final connection of tour in edge target
    tour_len += W_values[j][nodes[0]]
    return tour_len


def is_valid_tour(nodes, num_nodes):
    """Sanity check: tour visits all nodes given.
    """
    return sorted(nodes) == [i for i in range(num_nodes)]


def mean_tour_len_edges(x_edges_values, y_pred_edges):
    """
    Computes mean tour length for given batch prediction as edge adjacency matrices (for PyTorch tensors).

    Args:
        x_edges_values: Edge values (distance) matrix (batch_size, num_nodes, num_nodes)
        y_pred_edges: Edge predictions (batch_size, num_nodes, num_nodes, voc_edges)

    Returns:
        mean_tour_len: Mean tour length over batch
    """
    y = F.softmax(y_pred_edges, dim=3)  # B x V x V x voc_edges
    y = y.argmax(dim=3)  # B x V x V
    # Divide by 2 because edges_values is symmetric
    tour_lens = (y.float() * x_edges_values.float()).sum(dim=1).sum(dim=1) / 2
    mean_tour_len = tour_lens.sum().to(dtype=torch.float).item() / tour_lens.numel()
    return mean_tour_len


def mean_tour_len_nodes(x_edges_values, bs_nodes):
    """
    Computes mean tour length for given batch prediction as node ordering after beamsearch (for PyTorch tensors).

    Args:
        x_edges_values: Edge values (distance) matrix (batch_size, num_nodes, num_nodes)
        bs_nodes: Node orderings (batch_size, num_nodes)

    Returns:
        mean_tour_len: Mean tour length over batch
    """
    y = bs_nodes.cpu().numpy()
    W_val = x_edges_values.cpu().numpy()
    running_tour_len = 0
    for batch_idx in range(y.shape[0]):
        for y_idx in range(y[batch_idx].shape[0] - 1):
            i = y[batch_idx][y_idx]
            j = y[batch_idx][y_idx + 1]
            running_tour_len += W_val[batch_idx][i][j]
        # Close the cycle by returning to the tour's first node (which is not node 0 when random_start=True)
        running_tour_len += W_val[batch_idx][j][y[batch_idx][0]]
    return running_tour_len / y.shape[0]


def beamsearch_tour_nodes(y_pred_edges, beam_size, batch_size, num_nodes, dtypeFloat, dtypeLong, probs_type='raw', random_start=False):
    """
    Performs beamsearch procedure on edge prediction matrices and returns the most probable TSP tours.

    Args:
        y_pred_edges: Predictions for edges (batch_size, num_nodes, num_nodes, voc_edges)
        beam_size: Beam size
        batch_size: Batch size
        num_nodes: Number of nodes in TSP tours
        dtypeFloat: Float data type (for GPU/CPU compatibility)
        dtypeLong: Long data type (for GPU/CPU compatibility)
        probs_type: Type of probability values being handled by beamsearch (either 'raw'/'logits'/'argmax'(TODO))
        random_start: Flag for using fixed (at node 0) vs. random starting points for beamsearch

    Returns:
        hyp: TSP tours in terms of node ordering (batch_size, num_nodes)

    """
    if probs_type == 'raw':
        # Compute softmax over edge prediction matrix
        print("y_pred_edges shape:", y_pred_edges.shape)
        print("y_pred_edges: ", y_pred_edges)
        y = F.softmax(y_pred_edges, dim=3)  # B x V x V x voc_edges
        print("y shape:", y.shape)
        print("y: ", y)
        # Consider the second dimension only
        y = y[:, :, :, 1]  # B x V x V
        print("y shape:", y.shape)
        print("y: ", y)
    elif probs_type == 'logits':
        # Compute logits over edge prediction matrix
        y = F.log_softmax(y_pred_edges, dim=3)  # B x V x V x voc_edges
        # Consider the second dimension only
        y = y[:, :, :, 1]  # B x V x V
        y[y == 0] = -1e-20  # Set 0s (i.e. log(1)s) to very small negative number
    # Perform beamsearch
    beamsearch = Beamsearch(beam_size, batch_size, num_nodes, dtypeFloat, dtypeLong, probs_type, random_start)
    trans_probs = y.gather(1, beamsearch.get_current_state())
    print("Initial beamsearch state:", beamsearch.get_current_state())
    print("Transition probabilities shape:", trans_probs.shape)
    print("Transition probabilities:", trans_probs)
    for step in range(num_nodes - 1):
        beamsearch.advance(trans_probs)
        trans_probs = y.gather(1, beamsearch.get_current_state().type(dtypeLong))
        print(f"Step {step}: current state = {beamsearch.get_current_state()}")
        print(f"Step {step}: transition probabilities = {trans_probs}")
    # Find TSP tour with highest probability among beam_size candidates
    ends = torch.zeros(batch_size, 1).type(dtypeLong)
    hyp = beamsearch.get_hypothesis(ends)
    print("Hypotheses shape:", hyp.shape)
    print("Hypotheses:", hyp)
    return hyp


def beamsearch_tour_nodes_shortest(y_pred_edges, x_edges_values, beam_size, batch_size, num_nodes,
                                   dtypeFloat, dtypeLong, probs_type='raw', random_start=False):
    """
    Performs beamsearch procedure on edge prediction matrices and returns possible TSP tours.

    Final predicted tour is the one with the shortest tour length.
    (Standard beamsearch returns the one with the highest probability and does not take length into account.)

    Args:
        y_pred_edges: Predictions for edges (batch_size, num_nodes, num_nodes, voc_edges)
        x_edges_values: Input edge distance matrix (batch_size, num_nodes, num_nodes)
        beam_size: Beam size
        batch_size: Batch size
        num_nodes: Number of nodes in TSP tours
        dtypeFloat: Float data type (for GPU/CPU compatibility)
        dtypeLong: Long data type (for GPU/CPU compatibility)
        probs_type: Type of probability values being handled by beamsearch (either 'raw'/'logits'/'argmax'(TODO))
        random_start: Flag for using fixed (at node 0) vs. random starting points for beamsearch

    Returns:
        shortest_tours: TSP tours in terms of node ordering (batch_size, num_nodes)

    """
    if probs_type == 'raw':
        # Compute softmax over edge prediction matrix
        y = F.softmax(y_pred_edges, dim=3)  # B x V x V x voc_edges
        # Consider the second dimension only
        y = y[:, :, :, 1]  # B x V x V
    elif probs_type == 'logits':
        # Compute logits over edge prediction matrix
        y = F.log_softmax(y_pred_edges, dim=3)  # B x V x V x voc_edges
        # Consider the second dimension only
        y = y[:, :, :, 1]  # B x V x V
        y[y == 0] = -1e-20  # Set 0s (i.e. log(1)s) to very small negative number
    # Perform beamsearch
    beamsearch = Beamsearch(beam_size, batch_size, num_nodes, dtypeFloat, dtypeLong, probs_type, random_start)
    trans_probs = y.gather(1, beamsearch.get_current_state())
    for step in range(num_nodes - 1):
        beamsearch.advance(trans_probs)
        trans_probs = y.gather(1, beamsearch.get_current_state().type(dtypeLong))
    # Initially assign shortest_tours as most probable tours i.e. standard beamsearch
    ends = torch.zeros(batch_size, 1).type(dtypeLong)
    shortest_tours = beamsearch.get_hypothesis(ends)
    # Compute current tour lengths
    shortest_lens = [1e6] * len(shortest_tours)
    for idx in range(len(shortest_tours)):
        shortest_lens[idx] = tour_nodes_to_tour_len(shortest_tours[idx].cpu().numpy(),
                                                    x_edges_values[idx].cpu().numpy())
    # Iterate over all positions in beam (except position 0 --> highest probability)
    for pos in range(1, beam_size):
        ends = pos * torch.ones(batch_size, 1).type(dtypeLong)  # New positions
        hyp_tours = beamsearch.get_hypothesis(ends)
        for idx in range(len(hyp_tours)):
            hyp_nodes = hyp_tours[idx].cpu().numpy()
            hyp_len = tour_nodes_to_tour_len(hyp_nodes, x_edges_values[idx].cpu().numpy())
            # Replace tour in shortest_tours if new length is shorter than current best
            if hyp_len < shortest_lens[idx] and is_valid_tour(hyp_nodes, num_nodes):
                shortest_tours[idx] = hyp_tours[idx]
                shortest_lens[idx] = hyp_len
    return shortest_tours
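
As a quick sanity check on the tour helpers above, the following sketch computes closed-tour lengths from coordinates and tests tour validity on a unit square. The function names and the coordinates are illustrative assumptions. The closing edge returns to the first node of the tour, whatever that node is.

```python
import math


def tour_length(tour, coords):
    """Length of the closed tour that visits coords in the given order."""
    total = 0.0
    for idx in range(len(tour)):
        x1, y1 = coords[tour[idx]]
        x2, y2 = coords[tour[(idx + 1) % len(tour)]]  # wrap around to the first node
        total += math.hypot(x2 - x1, y2 - y1)
    return total


def visits_every_node_once(tour, num_nodes):
    """True when the tour is a permutation of 0..num_nodes-1."""
    return sorted(tour) == list(range(num_nodes))


# Corners of a unit square, listed in perimeter order
square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
perimeter = tour_length([0, 1, 2, 3], square)
rotated = tour_length([2, 3, 0, 1], square)    # same cycle, different start node
crossing = tour_length([0, 2, 1, 3], square)   # uses both diagonals
print(perimeter, rotated, crossing)
```

The rotated tour has the same length as the perimeter tour, which is the property a random starting node relies on, while the crossing tour is longer because it uses both diagonals.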

In [None]:
def plot_predictions_with_beam_search(x_nodes_coord, x_edges, x_edges_values, y_edges, y_pred_edges, beam_search_predictions, num_plots=3):
    """
    Plots the groundtruth TSP tour, the prediction heatmap, and the tour found by beam search.

    Args:
        x_nodes_coord: Input node coordinates (batch_size, num_nodes, node_dim)
        x_edges: Input edge adjacency matrix (batch_size, num_nodes, num_nodes)
        x_edges_values: Input edge distance matrix (batch_size, num_nodes, num_nodes)
        y_edges: Groundtruth labels for edges (batch_size, num_nodes, num_nodes)
        y_pred_edges: Predictions for edges (batch_size, num_nodes, num_nodes, voc_edges)
        beam_search_predictions: Predicted node ordering from beam search (batch_size, num_nodes)
        num_plots: Number of figures to plot
    """
    y = F.softmax(y_pred_edges, dim=3)  # B x V x V x voc_edges
    y_bins = y.argmax(dim=3)  # Binary predictions: B x V x V
    y_probs = y[:,:,:,1]  # Prediction probabilities: B x V x V

    for f_idx, idx in enumerate(np.random.choice(len(y), num_plots, replace=False)):
        f = plt.figure(f_idx, figsize=(15, 6))

        # Extract relevant data for the selected sample
        x_coord = x_nodes_coord[idx].cpu().numpy()
        W = x_edges[idx].cpu().numpy()
        W_val = x_edges_values[idx].cpu().numpy()
        W_target = y_edges[idx].cpu().numpy()
        W_sol_bins = y_bins[idx].cpu().numpy()
        W_sol_probs = y_probs[idx].cpu().numpy()

        # Groundtruth plot
        plt1 = f.add_subplot(131)
        plot_tsp(plt1, x_coord.squeeze(), W.squeeze(), W_val.squeeze(), W_target.squeeze(), 'Groundtruth')

        # Prediction heatmap plot
        plt2 = f.add_subplot(132)
        plot_tsp_heatmap(plt2, x_coord.squeeze(), W_val.squeeze(), W_sol_probs.squeeze(), 'Prediction Heatmap')

        # Beam search plot
        beam_search_route = beam_search_predictions[idx].cpu().numpy().tolist()  # Convert to list of node indices
        W_pred = tour_nodes_to_W(beam_search_route)  # Convert route to edge matrix

        # Compute the tour length for the beam search tour
        tour_length = tour_nodes_to_tour_len(beam_search_route, W_val.squeeze())

        # Dummy target (since we don't have groundtruth for beam search)
        dummy_target = np.zeros_like(W_val.squeeze())

        # Beam search tour plot
        plt3 = f.add_subplot(133)
        plot_tsp(plt3, x_coord.squeeze(), W_pred, W_val.squeeze(), dummy_target,
                 title=f"Beam Search Tour (Length: {tour_length:.2f})")

        # Show the plot
        plt.show()


In [None]:
# Define beam search parameters:
beam_size = 5  # or any beam size you prefer
batch_size = y_preds.shape[0]
num_nodes = y_preds.shape[1]
dtypeFloat = torch.cuda.FloatTensor if y_preds.is_cuda else torch.FloatTensor
dtypeLong = torch.cuda.LongTensor if y_preds.is_cuda else torch.LongTensor

# Call the beam search helper function for node ordering based on edge predictions:
predicted_tours = beamsearch_tour_nodes(
    y_preds,
    beam_size,
    batch_size,
    num_nodes,
    dtypeFloat,
    dtypeLong,
    probs_type='raw',    # or 'logits'
    random_start=True
)

print("Predicted TSP tours:", predicted_tours)



In [None]:
plot_predictions_with_beam_search(x_nodes_coord, x_edges, x_edges_values, y_edges, y_preds, predicted_tours, num_plots=num_samples)

## Residual Gated GNN model (simplified model without TransformerConv)

In [None]:
class ResidualGatedGCNLayer_v2(nn.Module):
    """Residual Gated GCN layer.
    """
    def __init__(self, hidden_dim, aggregation="mean"):
        super(ResidualGatedGCNLayer_v2, self).__init__()
        self.node_feat = NodeFeatures(hidden_dim, aggregation)
        self.edge_feat = EdgeFeatures(hidden_dim)
        self.bn_node = BatchNormNode(hidden_dim)
        self.bn_edge = BatchNormEdge(hidden_dim)
        self.dropout_layer = nn.Dropout(0.1)

    def forward(self, x, e):
        """
        Args:
            x: Node features (batch_size, num_nodes, hidden_dim)
            e: Edge features (batch_size, num_nodes, num_nodes, hidden_dim)

        Returns:
            x_new: Updated node features (batch_size, num_nodes, hidden_dim)
            e_new: Updated edge features (batch_size, num_nodes, num_nodes, hidden_dim)
        """
        # Gate edge features
        edge_gate = torch.sigmoid(e)

        # Node and edge feature transformation
        x_new = self.node_feat(x, edge_gate)
        e_new = self.edge_feat(x, e)

        # Batch normalization
        x_new = self.bn_node(x_new)
        e_new = self.bn_edge(e_new)

        # Residual connection
        x_new = F.relu(x + x_new)
        e_new = F.relu(e + e_new)

        x_new = self.dropout_layer(x_new)
        e_new = self.dropout_layer(e_new)

        return x_new, e_new

class ResidualGatedGCNModel_v2(nn.Module):
    """Residual Gated GCN Model for TSP.
    """
    def __init__(self, config, dtypeFloat, dtypeLong):
        super(ResidualGatedGCNModel_v2, self).__init__()
        self.config = config
        self.dtypeFloat = dtypeFloat
        self.dtypeLong = dtypeLong

        # Initial embedding layers
        self.node_embedding = nn.Linear(config['node_dim'], config['hidden_dim'], bias=False)
        self.edge_embedding = nn.Linear(1, config['hidden_dim'], bias=False)

        # GCN layers
        gcn_layers = []
        for layer in range(config['num_layers']):
            gcn_layers.append(ResidualGatedGCNLayer_v2(config['hidden_dim'], config['aggregation']))
        self.gcn_layers = nn.ModuleList(gcn_layers)

        # MLP for edge classification
        self.mlp_edges = MLP(config['hidden_dim'], config['voc_edges_out'], config['mlp_layers'])

    def forward(self, x_edges, x_edges_values, x_nodes, x_nodes_coord, y_edges=None, edge_cw=None):
        """
        Args:
            x_edges: Input edge connectivity (batch_size, num_nodes, num_nodes)
            x_edges_values: Input edge distances (batch_size, num_nodes, num_nodes)
            x_nodes: Input node features (batch_size, num_nodes)
            x_nodes_coord: Input node coordinates (batch_size, num_nodes, node_dim)
            y_edges: Edge labels (batch_size, num_nodes, num_nodes)
            edge_cw: Class weights for edge loss

        Returns:
            y_pred_edges: Edge predictions (batch_size, num_nodes, num_nodes, voc_edges_out)
            loss: Cross-entropy loss
        """
        # Problem size
        batch_size, num_nodes = x_nodes.size()

        # Initial embeddings
        x = self.node_embedding(x_nodes_coord)  # B x V x H

        # Create edge features
        edge_feat = x_edges_values.unsqueeze(3)  # B x V x V x 1
        e = self.edge_embedding(edge_feat)  # B x V x V x H

        # Apply GCN layers
        for layer in self.gcn_layers:
            x, e = layer(x, e)

        # MLP classifier for edges
        # To apply the MLP on each edge, we reshape the tensor
        e_flat = e.reshape(-1, self.config['hidden_dim'])  # (B*V*V) x H
        y_pred_edges_flat = self.mlp_edges(e_flat)  # (B*V*V) x 2
        y_pred_edges = y_pred_edges_flat.reshape(batch_size, num_nodes, num_nodes, -1)  # B x V x V x 2

        # Compute loss if training
        if y_edges is not None:
            # Convert edge predictions and labels for loss computation
            # We need to flatten them to (batch_size*num_nodes*num_nodes, num_classes)
            y_pred_edges_flat = y_pred_edges.reshape(-1, self.config['voc_edges_out'])
            y_edges_flat = y_edges.reshape(-1)
            
            device = y_pred_edges_flat.device  # Determine the device

            loss = None  # Initialize loss
            # Compute loss with class weights if provided
            if edge_cw is not None:
                if not isinstance(edge_cw, torch.Tensor):
                    # In case a non-tensor was passed (a dummy_edge_cw tensor is expected)
                    edge_cw = torch.tensor(edge_cw, dtype=torch.float, device=device)
                else:
                    # Just move to the device and enforce the dtype
                    edge_cw = edge_cw.to(device=device, dtype=torch.float)

                loss_fct = nn.CrossEntropyLoss(weight=edge_cw)
                loss = loss_fct(y_pred_edges_flat, y_edges_flat)
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(y_pred_edges_flat, y_edges_flat)

            return y_pred_edges, loss
        else:
            return y_pred_edges
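
The class weights passed to the loss compensate for label imbalance: assuming a symmetric 0/1 tour adjacency matrix, a tour on V nodes contributes 2V ones among V² entries. The sketch below reproduces the "balanced" heuristic, n_samples / (n_classes * class_count), in plain Python; the helper name and the 10-node example are assumptions for illustration.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Weight for each class: n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * counts[cls]) for cls in sorted(counts)}


# Flattened labels for a hypothetical 10-node tour: 20 ones among 100 entries
toy_num_nodes = 10
toy_labels = [1] * (2 * toy_num_nodes) + [0] * (toy_num_nodes ** 2 - 2 * toy_num_nodes)
weights = balanced_class_weights(toy_labels)
print(weights)
```

In this example the rare "edge" class receives four times the weight of the "no edge" class, so errors on tour edges are penalized more heavily.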

# Hybrid methods for solving the traveling salesman problem

### Simulated annealing + 2-opt (outer loop)
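
A 2-opt move reverses one segment of the tour, and for symmetric distances its effect on the tour length depends only on the two edges it replaces. The sketch below, with hypothetical helper names and made-up integer distances, compares that constant-time delta against a full recomputation for every segment of a small instance. Reversing the entire tour is treated as a special case, because it leaves the cycle unchanged.

```python
import random


def closed_tour_length(tour, dist):
    """Total length of the cycle, including the edge back to the start."""
    n = len(tour)
    return sum(dist[tour[k]][tour[(k + 1) % n]] for k in range(n))


def two_opt_delta(tour, dist, i, j):
    """Length change from reversing tour[i..j] (symmetric distances, i < j)."""
    n = len(tour)
    if i == 0 and j == n - 1:
        return 0.0  # reversing the whole cycle leaves every edge in place
    a, b = tour[i - 1], tour[i]
    c, d = tour[j], tour[(j + 1) % n]
    return (dist[a][c] + dist[b][d]) - (dist[a][b] + dist[c][d])


# Random symmetric instance with made-up integer distances
random.seed(0)
n = 7
dist = [[0] * n for _ in range(n)]
for u in range(n):
    for v in range(u + 1, n):
        dist[u][v] = dist[v][u] = random.randint(1, 20)

tour = list(range(n))
max_error = 0
for i in range(n):
    for j in range(i + 1, n):
        reversed_tour = tour[:i] + tour[i:j + 1][::-1] + tour[j + 1:]
        exact = closed_tour_length(reversed_tour, dist) - closed_tour_length(tour, dist)
        max_error = max(max_error, abs(exact - two_opt_delta(tour, dist, i, j)))
print(max_error)
```

Because the distances are integers, the shortcut and the full recomputation can be compared exactly, with no floating-point tolerance.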

In [None]:
import numpy as np
import math
import random
import time

class TSP_2opt_SA_outer:
    def __init__(self, distance_matrix):
        """
        Initialize the solver with a distance matrix.

        Args:
            distance_matrix: 2D numpy array where distance_matrix[i][j] is the distance from city i to city j
        """
        self.distances = distance_matrix
        self.num_cities = len(distance_matrix)

    def calculate_tour_length(self, tour):
        """Calculate the total length of a tour."""
        return sum(self.distances[tour[i]][tour[(i + 1) % self.num_cities]] for i in range(self.num_cities))

    def two_opt_swap(self, tour, i, j):
        """
        Perform a 2-opt swap by reversing the segment between positions i and j.
        Returns a new tour.
        """
        new_tour = tour.copy()
        new_tour[i:j+1] = new_tour[i:j+1][::-1]
        return new_tour

    def delta_two_opt(self, tour, i, j):
        """
        Calculate the cost difference (delta) for a 2-opt swap between positions i and j in O(1) time.
        Assumes a symmetric distance matrix and i < j.
        """
        n = self.num_cities
        if i == 0 and j == n - 1:
            # Reversing the whole tour keeps the same cycle, so the length is unchanged
            return 0.0
        a, b = tour[i - 1], tour[i]
        c, d = tour[j], tour[(j + 1) % n]
        delta = (self.distances[a][c] + self.distances[b][d]) - (self.distances[a][b] + self.distances[c][d])
        return delta

    def get_random_neighbor(self, tour):
        """
        Get a random neighbor by performing a 2-opt swap between two random indices.
        Returns the new tour and the cost difference (delta).
        """
        i, j = sorted(random.sample(range(self.num_cities), 2))
        new_tour = self.two_opt_swap(tour, i, j)
        delta = self.delta_two_opt(tour, i, j)
        return new_tour, delta, i, j

    def acceptance_probability(self, delta, temperature):
        """
        Calculate the acceptance probability.
        Returns 1.0 for improvements.
        """
        return 1.0 if delta < 0 else math.exp(-delta / temperature)

    def double_bridge_move(self, tour):
        """
        Perform a double-bridge move to perturb the current tour.
        This move breaks the tour into four segments and rearranges them.
        Expects `tour` to be a Python list; tours with fewer than 4 cities are returned unchanged.
        """
        n = self.num_cities
        if n < 4:
            return tour.copy()  # Not enough cities for three distinct cut points
        pos1 = random.randint(1, n // 4)
        pos2 = random.randint(pos1 + 1, n // 2)
        pos3 = random.randint(pos2 + 1, 3 * n // 4)
        new_tour = tour[0:pos1] + tour[pos3:] + tour[pos2:pos3] + tour[pos1:pos2]
        return new_tour

    def solve_sa(self, initial_tour, sa_params, use_perturbation=True):
        """
        Run the SA phase from a given starting tour.
        Returns the best tour, its length, number of iterations, and elapsed time.
        """
        current_tour = initial_tour.copy()
        current_length = self.calculate_tour_length(current_tour)
        best_tour, best_length = current_tour.copy(), current_length

        temperature = sa_params.get('initial_temp', 500.0)
        cooling_rate = sa_params.get('cooling_rate', 0.9995)
        min_temp = sa_params.get('min_temp', 1e-6)
        max_iterations = sa_params.get('max_iterations', 10000)
        max_no_improve = sa_params.get('max_no_improve', 1000)
        perturb_interval = sa_params.get('perturb_interval', 500)
        perturb_prob = sa_params.get('perturb_prob', 0.1)

        iterations = 0
        no_improve = 0
        start_time = time.time()

        while iterations < max_iterations and temperature > min_temp and no_improve < max_no_improve:
            iterations += 1
            # Optionally apply a perturbation
            if use_perturbation and iterations % perturb_interval == 0 and random.random() < perturb_prob:
                perturbed = self.double_bridge_move(current_tour)
                perturbed_length = self.calculate_tour_length(perturbed)
                # Accept the perturbation if it improves the tour, otherwise with probability 0.2
                if perturbed_length < current_length or random.random() < 0.2:
                    current_tour = perturbed
                    current_length = perturbed_length
                    no_improve = 0
                    # Record the perturbed tour if it is the best seen so far
                    if current_length < best_length:
                        best_tour = current_tour.copy()
                        best_length = current_length

            neighbor, delta, i_swap, j_swap = self.get_random_neighbor(current_tour)
            if self.acceptance_probability(delta, temperature) > random.random():
                current_tour = neighbor
                current_length += delta
                if current_length < best_length:
                    best_tour = current_tour.copy()
                    best_length = current_length
                    no_improve = 0
                else:
                    no_improve += 1
            else:
                no_improve += 1
            temperature *= cooling_rate

        elapsed = time.time() - start_time
        return best_tour, best_length, iterations, elapsed

    def improve_tour_with_2opt(self, tour, max_iterations=500):
        """
        Run deterministic 2-opt local search.
        Returns the improved tour, its length, and the number of iterations.
        """
        best_tour = tour.copy()
        best_length = self.calculate_tour_length(best_tour)
        iterations = 0
        improved = True

        while improved and iterations < max_iterations:
            improved = False
            iterations += 1
            for i in range(1, self.num_cities - 1):
                for j in range(i + 1, self.num_cities):
                    if i == 1 and j == self.num_cities - 1:
                        continue  # Avoid trivial full reversal
                    delta = self.delta_two_opt(best_tour, i, j)
                    if delta < 0:
                        best_tour = self.two_opt_swap(best_tour, i, j)
                        best_length += delta
                        improved = True
                        break
                if improved:
                    break
        return best_tour, best_length, iterations

    def solve(self, outer_iterations=10, sa_params=None, local_search_after_sa=True, use_perturbation=True):
        """
        Hybrid iterated approach: repeatedly run SA followed by (optional) 2-opt.
        This iterated local search allows restarting from perturbed solutions.

        Args:
            outer_iterations: Number of outer iterations.
            sa_params: Parameters for the SA phase.
            local_search_after_sa: Whether to perform 2-opt after SA.
            use_perturbation: Whether to use the double-bridge perturbation in SA.

        Returns:
            overall_best_tour, overall_best_length: Best solution found.
        """
        if sa_params is None:
            sa_params = {
                'initial_temp': 500.0,
                'cooling_rate': 0.9995,
                'min_temp': 1e-6,
                'max_iterations': 10000,
                'max_no_improve': 1000,
                'perturb_interval': 500,
                'perturb_prob': 0.1
            }

        overall_best_tour = None
        overall_best_length = float('inf')
        overall_times = []

        # Start with a random tour
        current_tour = list(range(self.num_cities))
        random.shuffle(current_tour)

        for outer in range(1, outer_iterations + 1):
            outer_start = time.time()
            print(f"Outer iteration {outer}/{outer_iterations}")
            # Run SA from the current tour
            sa_tour, sa_length, sa_iters, sa_time = self.solve_sa(current_tour, sa_params, use_perturbation)
            print(f"  SA: {sa_iters} iterations, {sa_time:.4f} seconds, tour length: {sa_length:.2f}")

            # Apply 2-opt improvement if enabled
            if local_search_after_sa:
                print("  Applying deterministic 2-opt local search...")
                opt_tour, opt_length, opt_iters = self.improve_tour_with_2opt(sa_tour)
                print(f"  2-opt: {opt_iters} iterations, final tour length: {opt_length:.2f}")
                candidate_length = opt_length
                candidate_tour = opt_tour
            else:
                candidate_length = sa_length
                candidate_tour = sa_tour

            total_outer = time.time() - outer_start
            print(f"  Outer iteration time: {total_outer:.4f} seconds")

            # Update overall best if found
            if candidate_length < overall_best_length:
                overall_best_length = candidate_length
                overall_best_tour = candidate_tour.copy()
                print(f"  New best solution found: {overall_best_length:.2f}")
            # Use the candidate as starting point for next outer iteration
            current_tour = candidate_tour.copy()
            overall_times.append(total_outer)
            print("")

        avg_time = sum(overall_times)/len(overall_times) if overall_times else 0.0
        print(f"Best solution overall: {overall_best_length:.2f}")
        print(f"Mean outer iteration time: {avg_time:.4f} seconds")
        return overall_best_tour, overall_best_length

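The O(1) delta used by `delta_two_opt` can be sanity-checked against a full recomputation of the tour length on a small random instance. The sketch below is standalone and does not use the class above. The helper names (`tour_length`, `two_opt_delta`, `check_delta`) and the random instance are illustrative assumptions, and the full-reversal guard is only one possible way to handle the `i == 0, j == n - 1` case.

```python
import math
import random

def tour_length(dist, tour):
    """Full O(n) length of a closed tour."""
    n = len(tour)
    return sum(dist[tour[k]][tour[(k + 1) % n]] for k in range(n))

def two_opt_delta(dist, tour, i, j):
    """O(1) length change from reversing tour[i..j] (i < j), symmetric distances."""
    n = len(tour)
    if j - i + 1 >= n:
        return 0.0  # reversing the whole tour leaves the cycle unchanged
    a, b = tour[i - 1], tour[i]          # tour[-1] wraps around when i == 0
    c, d = tour[j], tour[(j + 1) % n]
    return (dist[a][c] + dist[b][d]) - (dist[a][b] + dist[c][d])

def check_delta(n=12, trials=200, seed=0):
    """Largest |exact - O(1)| discrepancy over random 2-opt moves (hypothetical helper)."""
    rng = random.Random(seed)
    pts = [(rng.random(), rng.random()) for _ in range(n)]
    dist = [[math.dist(p, q) for q in pts] for p in pts]
    tour = list(range(n))
    rng.shuffle(tour)
    worst = 0.0
    for _ in range(trials):
        i, j = sorted(rng.sample(range(n), 2))
        new_tour = tour[:i] + tour[i:j + 1][::-1] + tour[j + 1:]
        exact = tour_length(dist, new_tour) - tour_length(dist, tour)
        worst = max(worst, abs(exact - two_opt_delta(dist, tour, i, j)))
    return worst

print(f"largest discrepancy: {check_delta():.2e}")
```

`check_delta()` returns the largest absolute discrepancy over the sampled moves; with symmetric distances it should stay at floating-point noise level.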

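The default SA parameters interact: with geometric cooling the temperature after k steps is `initial_temp * cooling_rate**k`, so it can be worth checking which stopping criterion actually ends a run. The sketch below does this arithmetic for the defaults in `solve_sa`. The helper names are hypothetical, and the value `delta = 0.5` is an arbitrary illustration, since the distance scale of the instances is not assumed here.

```python
import math

def steps_to_reach(initial_temp, min_temp, cooling_rate):
    """Smallest k such that initial_temp * cooling_rate**k <= min_temp."""
    return math.ceil(math.log(min_temp / initial_temp) / math.log(cooling_rate))

def temperature_after(initial_temp, cooling_rate, steps):
    """Temperature after `steps` geometric cooling steps."""
    return initial_temp * cooling_rate ** steps

def acceptance_probability(delta, temperature):
    """Metropolis rule, as in the SA phase above."""
    return 1.0 if delta < 0 else math.exp(-delta / temperature)

# Default values copied from solve_sa
params = {'initial_temp': 500.0, 'cooling_rate': 0.9995,
          'min_temp': 1e-6, 'max_iterations': 10000}

k_min = steps_to_reach(params['initial_temp'], params['min_temp'], params['cooling_rate'])
t_end = temperature_after(params['initial_temp'], params['cooling_rate'], params['max_iterations'])

print(f"steps needed to reach min_temp: {k_min}")
print(f"temperature after max_iterations: {t_end:.3f}")
# Acceptance probability of a worsening move at the final temperature;
# delta = 0.5 is an arbitrary illustrative value
print(f"P(accept delta=0.5) at the end: {acceptance_probability(0.5, t_end):.3f}")
```

For the defaults, `min_temp` would only be reached after roughly 40,000 steps, so `max_iterations` (or `max_no_improve`) ends each SA phase first, at a temperature a little above 3. If the coordinates lie in the unit square, where worsening deltas are well below 1, most uphill moves would still be accepted at that point, so a smaller `initial_temp` may be worth trying in that setting.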
## Branch-and-cut method and integration with graph neural networks

In [None]:
import time

import numpy as np
import torch
from scipy.spatial.distance import pdist, squareform
import networkx as nx
import pulp
from pulp import LpVariable, LpProblem, lpSum, LpMinimize, value

class GNNBranchCutSolver:
    """
    Solver that integrates GNN predictions with Branch & Cut for TSP
    """
    def __init__(self, gnn_model, threshold=0.6, fixing_percentage=0.2, elimination_percentage=0.2):
        """
        Args:
            gnn_model: Trained GNN model
            threshold: Probability threshold for edge fixing
            fixing_percentage: Percentage of highest probability edges to fix
            elimination_percentage: Percentage of lowest probability edges to eliminate
        """
        self.gnn_model = gnn_model
        self.threshold = threshold
        self.fixing_percentage = fixing_percentage
        self.elimination_percentage = elimination_percentage
        print(f"threshold={self.threshold}, fixing_percentage={self.fixing_percentage}, "
              f"elimination_percentage={self.elimination_percentage}")

    def get_edge_probabilities(self, nodes_coord):
        # Validate the input type
        if not isinstance(nodes_coord, np.ndarray):
            try:
                # Try to convert if this is a tensor or a list of lists
                nodes_coord = np.array(nodes_coord)
            except Exception as e:
                raise TypeError(f"Input nodes_coord must be a NumPy array or convertible. Error: {e}")
        if nodes_coord.ndim != 2 or nodes_coord.shape[1] != 2:
            raise ValueError(f"Input nodes_coord must have shape (num_nodes, 2), but got {nodes_coord.shape}")

        num_nodes = nodes_coord.shape[0]
        batch_size = 1  # Process one instance at a time

        # Prepare the input tensors
        try:
            # 1. Distance matrix
            dist_matrix = squareform(pdist(nodes_coord, metric='euclidean'))

            # 2. Input tensors for the GNN
            # x_edges: adjacency matrix (1 = edge, 2 = self-loop, 0 = no edge)
            x_edges_np = np.ones((batch_size, num_nodes, num_nodes), dtype=np.int64)
            # Fill the diagonal with the self-loop value (2, as in GoogleTSPReader).
            # Change this if the model expects 0 on the diagonal.
            np.fill_diagonal(x_edges_np[0], 2)

            # x_edges_values: distance matrix
            x_edges_values_np = np.expand_dims(dist_matrix, axis=0).astype(np.float32)

            # x_nodes: node features (all ones for TSP)
            x_nodes_np = np.ones((batch_size, num_nodes), dtype=np.int64)

            # x_nodes_coord: node coordinates
            x_nodes_coord_np = np.expand_dims(nodes_coord, axis=0).astype(np.float32)

            # Use the device the model's parameters live on
            device = next(self.gnn_model.parameters()).device

            # Convert to PyTorch tensors with explicit dtypes and move them to the device
            x_edges = torch.from_numpy(x_edges_np).to(device=device, dtype=torch.long)
            x_edges_values = torch.from_numpy(x_edges_values_np).to(device=device, dtype=torch.float32)
            x_nodes = torch.from_numpy(x_nodes_np).to(device=device, dtype=torch.long)
            x_nodes_coord_tensor = torch.from_numpy(x_nodes_coord_np).to(device=device, dtype=torch.float32)

        except Exception as e:
            print(f"Error during input tensor preparation in get_edge_probabilities: {e}")
            # Fall back to an all-zero probability matrix
            return np.zeros((num_nodes, num_nodes))

        # Create dummy y_edges and edge_cw for the forward pass
        dummy_y_edges = torch.zeros(batch_size, num_nodes, num_nodes, dtype=torch.long, device=device)
        # Weights [1.0, 1.0] because there are two classes (edge present / absent)
        dummy_edge_cw = torch.ones(2, dtype=torch.float, device=device)

        # Get predictions from GNN
        self.gnn_model.eval()  # Make sure the model is in evaluation mode
        with torch.no_grad():
            # Pass the dummy arguments
            try:
                result = self.gnn_model(x_edges, x_edges_values, x_nodes, x_nodes_coord_tensor,
                                        dummy_y_edges, dummy_edge_cw)
                # The model may return a tuple or the predictions alone
                if isinstance(result, tuple):
                    edge_preds = result[0]  # Take the predictions (logits)
                else:
                    edge_preds = result
            except Exception as e:
                print(f"Error during model forward pass in get_edge_probabilities: {e}")
                return np.zeros((num_nodes, num_nodes))  # Return zeros on error


        # Convert to probabilities using softmax
        try:
            # Make sure edge_preds has the expected shape [B, V, V, Voc]
            if edge_preds.dim() == 4 and edge_preds.shape[0] == 1 and \
               edge_preds.shape[1] == num_nodes and edge_preds.shape[2] == num_nodes:
                # Take the probabilities of class 1 (edge present)
                edge_probs_tensor = torch.softmax(edge_preds[0], dim=-1)[:, :, 1]
                edge_probs = edge_probs_tensor.cpu().numpy()
            elif edge_preds.dim() == 3 and edge_preds.shape[0] == num_nodes and edge_preds.shape[1] == num_nodes:
                # Handle the case where the batch dimension was squeezed somewhere
                edge_probs_tensor = torch.softmax(edge_preds, dim=-1)[:, :, 1]
                edge_probs = edge_probs_tensor.cpu().numpy()
            else:
                raise ValueError(f"Unexpected output shape from GNN model: {edge_preds.shape}. Expected 4D [1, N, N, Voc] or 3D [N, N, Voc].")

        except Exception as e_prob:
            print(f"Error processing edge_preds shape {edge_preds.shape} in get_edge_probabilities: {e_prob}")
            return np.zeros((num_nodes, num_nodes))

        return edge_probs

    # The optimal_edges_pure parameter is accepted for comparison with pure B&C (currently unused)
    def solve_tsp(self, nodes_coord, optimal_edges_pure=None):
        """Solve TSP using GNN predictions and Branch & Cut.

        Args:
            nodes_coord: Coordinates of nodes (num_nodes, 2)
            optimal_edges_pure: Optional set of tuples representing optimal edges from pure B&C for debugging

        Returns:
            tour: Tour found, as list of node indices (optimal for the restricted ILP;
                  edge fixing and elimination can exclude the true optimum)
            tour_length: Length of that tour
            total_time: Total time of the method in seconds (GNN prediction + ILP)
        """
        total_start_time = time.time()  # Timer for the whole method

        num_nodes = nodes_coord.shape[0]
        dist_matrix = squareform(pdist(nodes_coord, metric='euclidean'))

        # --- GNN time ---
        gnn_start_time = time.time()
        edge_probs = self.get_edge_probabilities(nodes_coord)
        gnn_time = time.time() - gnn_start_time

        # --- ILP time ---
        ilp_start_time = time.time()
        model = LpProblem("TSP_GNN_BranchCut", LpMinimize)
        x = {}
        for i in range(num_nodes):
            for j in range(i+1, num_nodes):
                x[i, j] = LpVariable(f'x_{i}_{j}', cat='Binary')

        def get_edge_var(i, j):
            # Order the indices so that i < j, matching the keys of x
            idx1, idx2 = min(i, j), max(i, j)
            if idx1 == idx2:
                return None  # Should not happen, but guard against it
            return x.get((idx1, idx2))

        model += lpSum([dist_matrix[i][j] * get_edge_var(i, j) for i in range(num_nodes) for j in range(i+1, num_nodes)])

        for i in range(num_nodes):
            # Degree constraint: exactly two incident edges (get_edge_var handles either index order)
            model += lpSum([get_edge_var(i, j) for j in range(num_nodes) if i != j]) == 2

        fixed_edges_count = 0
        eliminated_edges_count = 0

        # --- Edge fixing ---
        if self.fixing_percentage > 0:
            indices = [(i, j) for i in range(num_nodes) for j in range(i+1, num_nodes)]
            edge_probs_triu = np.array([edge_probs[i, j] for i, j in indices])
            sorted_indices = np.argsort(-edge_probs_triu)  # Descending
            num_edges_to_fix_limit = int(self.fixing_percentage * len(indices))

            for idx in sorted_indices[:num_edges_to_fix_limit]:
                i, j = indices[idx]
                # Fix the edge only if its probability exceeds the threshold
                if edge_probs[i, j] > self.threshold:
                    model += get_edge_var(i, j) == 1
                    fixed_edges_count += 1

        # --- Edge elimination ---
        if self.elimination_percentage > 0:
            indices = [(i, j) for i in range(num_nodes) for j in range(i+1, num_nodes)]
            edge_probs_triu = np.array([edge_probs[i, j] for i, j in indices])
            sorted_indices = np.argsort(edge_probs_triu)  # Ascending
            num_edges_to_elim_limit = int(self.elimination_percentage * len(indices))

            for idx in sorted_indices[:num_edges_to_elim_limit]:
                i, j = indices[idx]
                model += get_edge_var(i, j) == 0
                eliminated_edges_count += 1

        # --- Solve and subtour elimination loop ---
        subtour_elim_iterations = 0
        total_pulp_solve_time = 0.0

        while True:
            iter_solve_start = time.time()
            # Suppress the PuLP solver output
            solver_kwargs = {'msg': False}
            status = model.solve(pulp.PULP_CBC_CMD(**solver_kwargs))  # CBC solver via PuLP
            iter_solve_time = time.time() - iter_solve_start
            total_pulp_solve_time += iter_solve_time

            if pulp.LpStatus[status] != 'Optimal':
                print(f"[WARN GNN Solv] PuLP solver did not find an optimal solution (Status: {pulp.LpStatus[status]})")
                # Report the timings and return the error indicator
                ilp_time = time.time() - ilp_start_time
                total_time = time.time() - total_start_time
                print(f"[DEBUG GNN Solv] ILP + Subtour time: {ilp_time:.4f}s (PuLP solve time: {total_pulp_solve_time:.4f}s)")
                print(f"[DEBUG GNN Solv] Total GNN solve method time: {total_time:.4f}s")
                return [], float('inf'), total_time  # Error indicator


            edges = []
            for i in range(num_nodes):
                for j in range(i+1, num_nodes):
                    var = get_edge_var(i,j)
                    if var is not None and value(var) > 0.5:
                        edges.append((i, j))

            G = nx.Graph()
            G.add_edges_from(edges)
            components = list(nx.connected_components(G))

            if len(components) == 1:
                break

            subtour_elim_iterations += 1

            # Add constraints only for components with fewer than N nodes
            constraints_added_this_iter = 0
            for component in components:
                if len(component) < num_nodes:
                    component_nodes = list(component)
                    # Subtour elimination constraint: sum of x(i,j) over i,j in S <= |S| - 1
                    subtour_edges_sum = lpSum([get_edge_var(i, j)
                                               for i in component_nodes
                                               for j in component_nodes if i < j])  # i < j counts each edge once
                    # Constraint name, for debugging
                    constraint_name = f"SubtourElim_{subtour_elim_iterations}_{constraints_added_this_iter}"
                    model += subtour_edges_sum <= len(component_nodes) - 1, constraint_name
                    constraints_added_this_iter += 1
            if constraints_added_this_iter == 0 and len(components) > 1:
                print("[ERROR GNN Solv] Found multiple components but couldn't add elimination constraints!")
                ilp_time = time.time() - ilp_start_time
                total_time = time.time() - total_start_time
                print(f"[DEBUG GNN Solv] ILP + Subtour time: {ilp_time:.4f}s (PuLP solve time: {total_pulp_solve_time:.4f}s)")
                print(f"[DEBUG GNN Solv] Total GNN solve method time: {total_time:.4f}s")
                return [], float('inf'), total_time


        ilp_time = time.time() - ilp_start_time  # Time for the whole ILP part
        total_time = time.time() - total_start_time  # Total time of the method

        tour = self.edges_to_tour(edges)
        tour_length = self.calculate_tour_length(tour, dist_matrix)

        return tour, tour_length, total_time  # Return the total time of the method


    def edges_to_tour(self, edges):
        """Convert edge list to tour.

        Args:
            edges: List of edges as (i,j) tuples

        Returns:
            tour: List of node indices in tour order
        """
        G = nx.Graph()
        G.add_edges_from(edges)

        # Add any node indices missing from the edge list as isolated nodes
        num_nodes = max(max(e) for e in edges) + 1 if edges else 0
        for i in range(num_nodes):
            if i not in G:
                G.add_node(i)

        # Try to find the cycle with the built-in function
        try:
            cycle = nx.find_cycle(G, source=0)
            # Build the tour: start node + the second endpoint of each cycle edge
            tour = [cycle[0][0]] + [j for (_, j) in cycle]
            # If the cycle is closed (start == end), drop the repeated start node
            if tour[0] == tour[-1]:
                tour.pop()
            # If the tour does not cover all nodes, fall back to the DFS traversal below
            if len(tour) < num_nodes:
                raise ValueError("Cycle does not cover all nodes")
            return tour
        except Exception as e:
            # nx.find_cycle failed or the tour is incomplete: build the tour manually with DFS
            tour = []
            visited = set()
            def dfs(node, prev):
                visited.add(node)
                tour.append(node)
                for nbr in G.neighbors(node):
                    if nbr != prev and nbr not in visited:
                        dfs(nbr, node)
            dfs(0, None)
            # If the tour is incomplete, append the missing nodes (no Hamiltonian-cycle guarantee)
            for node in range(num_nodes):
                if node not in visited:
                    tour.append(node)
            return tour

    def calculate_tour_length(self, tour, dist_matrix):
        """Calculate tour length.

        Args:
            tour: List of node indices
            dist_matrix: Distance matrix

        Returns:
            tour_length: Total tour length
        """
        tour_length = 0
        for i in range(len(tour)):
            j = (i + 1) % len(tour)
            tour_length += dist_matrix[tour[i]][tour[j]]
        return tour_length

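The edge-selection rule used in `solve_tsp` (rank edges by predicted probability, fix the top share that also exceeds the threshold, eliminate the bottom share) can be sketched without PuLP or PyTorch. The function `select_edges` and the toy probability matrix below are hypothetical. The degree check and the "never eliminate a fixed edge" filter are additional safeguards assumed here for illustration; they are not part of the solver above, and they do not prevent fixed edges from forming a closed subtour.

```python
def select_edges(edge_probs, fixing_percentage=0.2, elimination_percentage=0.2, threshold=0.6):
    """Return (fixed, eliminated) lists of undirected edges (i, j) with i < j."""
    n = len(edge_probs)
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    # Rank all undirected edges by predicted probability, highest first
    ranked = sorted(pairs, key=lambda e: edge_probs[e[0]][e[1]], reverse=True)

    n_fix = int(fixing_percentage * len(pairs))
    n_elim = int(elimination_percentage * len(pairs))

    degree = [0] * n
    fixed = []
    for i, j in ranked[:n_fix]:
        if edge_probs[i][j] <= threshold:
            break  # ranked is descending, so no later edge passes the threshold
        # Assumption (not in the solver above): skip edges that would give a node degree 3
        if degree[i] < 2 and degree[j] < 2:
            fixed.append((i, j))
            degree[i] += 1
            degree[j] += 1

    fixed_set = set(fixed)
    # Lowest-probability edges first; never eliminate an edge that was just fixed
    eliminated = [e for e in ranked[::-1][:n_elim] if e not in fixed_set]
    return fixed, eliminated

# Toy symmetric probability matrix for 5 nodes (made-up values)
example = {(0, 1): 0.9, (1, 2): 0.8, (2, 3): 0.7, (3, 4): 0.65, (0, 4): 0.5,
           (2, 4): 0.3, (1, 3): 0.2, (0, 2): 0.1, (0, 3): 0.05, (1, 4): 0.02}
probs = [[0.0] * 5 for _ in range(5)]
for (i, j), p in example.items():
    probs[i][j] = probs[j][i] = p

fixed, eliminated = select_edges(probs)
print("fixed:", fixed)
print("eliminated:", eliminated)
```

Because fixing requires both a top-rank position and a probability above `threshold`, raising `fixing_percentage` alone does not fix more edges once the threshold becomes the binding condition.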
## Concorde TSP dataset

In [None]:
import numpy as np
from scipy.spatial.distance import pdist, squareform
import networkx as nx
from sklearn.utils import shuffle
import math
from typing import List, Tuple, Optional, Dict, Any, Iterator
import os
from tqdm.auto import tqdm

class ConcordeTSPInstance:
    """Container for the data of a single TSP instance."""
    def __init__(self, name: str, coordinates: np.ndarray, dist_matrix: np.ndarray,
                 optimal_tour: Optional[List[int]] = None, optimal_length: Optional[float] = None):
        self.name = name
        self.coordinates = coordinates
        self.dist_matrix = dist_matrix
        self.optimal_tour = optimal_tour
        self.optimal_length = optimal_length
        self.num_nodes = coordinates.shape[0]

    def get_nx_graph(self) -> nx.Graph:
        """Build and return a NetworkX graph with edge weights."""
        n = self.num_nodes
        graph = nx.complete_graph(n)
        for i in range(n):
            for j in range(i + 1, n):
                graph.edges[i, j]['weight'] = self.dist_matrix[i, j]
        return graph

    @staticmethod
    def calculate_tour_len(tour: List[int], dist_matrix: np.ndarray) -> float:
        """Static method that computes the length of a tour."""
        length = 0.0
        n = len(tour)
        if n == 0: return 0.0
        for i in range(n):
            u = tour[i]
            v = tour[(i + 1) % n]
            if 0 <= u < dist_matrix.shape[0] and 0 <= v < dist_matrix.shape[0]:
                 length += dist_matrix[u, v]
            else:
                 print(f"Warning: Invalid index in tour. u={u}, v={v}, matrix_shape={dist_matrix.shape}")
                 return float('inf')
        return length


class ConcordeTSPReader:
    """
    Reads Concorde TSP dataset files and returns problem instances.
    Optionally limits the maximum number of instances loaded.
    """
    def __init__(self, filepath: str, name_prefix: str = "instance",
                 max_instances: Optional[int] = None):
        """
        Initialize the reader.

        Args:
            filepath (str): Path to the dataset file (.txt).
            name_prefix (str): Prefix used to name the instances.
            max_instances (Optional[int]): Maximum number of instances to load
                                           from the file. If None, all
                                           instances are loaded.
        """
        self.filepath = filepath
        self.name_prefix = name_prefix
        # max_instances must be a positive number or None
        if max_instances is not None and max_instances <= 0:
            print("Warning: max_instances must be positive. Loading all instances.")
            self.max_instances = None
        else:
            self.max_instances = max_instances
        self.instances: List[ConcordeTSPInstance] = []
        self._load_data()

    def _load_data(self):
        """Load and parse the dataset file, showing progress and honouring the instance limit."""
        try:
            with open(self.filepath, "r") as f:
                # Read all lines up front so the progress-bar total can be computed
                lines = f.readlines()
        except FileNotFoundError:
            print(f"Error: Dataset file not found at {self.filepath}")
            self.instances = []
            return

        if not lines:
            print(f"Warning: File is empty: {self.filepath}")
            self.instances = []
            return

        filename_short = os.path.basename(self.filepath)
        print(f"Parsing file: {filename_short}")

        # Work out how many instances we aim to load
        num_lines_to_process = len(lines)
        if self.max_instances is not None:
            num_lines_to_process = min(len(lines), self.max_instances)
            print(f"Loading at most {self.max_instances} instances.")

        # Manual progress bar: it advances once per successfully loaded instance.
        # Wrapping the loop iterable as well would count each loaded line twice.
        pbar = tqdm(desc=f"Parsing {filename_short}",
                    total=num_lines_to_process,  # Total = target number of instances
                    unit=" instances",
                    ncols=1000,
                    ascii=True)

        for line_num, line in enumerate(lines):
            # Stop once the instance limit is reached
            if self.max_instances is not None and len(self.instances) >= self.max_instances:
                pbar.close()  # Close the progress bar: the limit was reached
                print(f"\nReached max_instances limit ({self.max_instances}). Stopping parsing.")
                break

            line = line.strip()
            if not line: continue

            parts = line.split()
            try:
                output_idx = parts.index('output')
            except ValueError:
                continue

            # --- Coordinates ---
            num_coord_parts = output_idx
            if num_coord_parts % 2 != 0: continue
            num_nodes = num_coord_parts // 2
            if num_nodes <= 1: continue

            try:
                coords_flat = [float(p) for p in parts[:num_coord_parts]]
                coordinates = np.array(coords_flat).reshape(num_nodes, 2)
            except (ValueError, IndexError):
                continue

            # --- Distance matrix ---
            dist_matrix = squareform(pdist(coordinates, metric='euclidean'))

            # --- Optimal tour (1-based in the file, converted to 0-based below) ---
            tour_parts = parts[output_idx + 1:]
            if tour_parts and tour_parts[-1] == '-1': tour_parts = tour_parts[:-1]

            try:
                optimal_tour_indices = [int(node) - 1 for node in tour_parts]
            except ValueError: continue

            if len(optimal_tour_indices) == num_nodes + 1 and optimal_tour_indices[0] == optimal_tour_indices[-1]:
                optimal_tour_0based = optimal_tour_indices[:-1]
            elif len(optimal_tour_indices) == num_nodes:
                 optimal_tour_0based = optimal_tour_indices
            else: continue

            if len(set(optimal_tour_0based)) != num_nodes or len(optimal_tour_0based) != num_nodes: continue

            # --- Optimal length ---
            optimal_length = ConcordeTSPInstance.calculate_tour_len(optimal_tour_0based, dist_matrix)
            if optimal_length == float('inf'): continue

            # --- Create the instance ---
            instance_name = f"{self.name_prefix}_{line_num+1}"
            instance = ConcordeTSPInstance(
                name=instance_name, coordinates=coordinates, dist_matrix=dist_matrix,
                optimal_tour=optimal_tour_0based, optimal_length=optimal_length
            )
            self.instances.append(instance)
            # Advance the progress bar only after an instance was added successfully
            pbar.update(1)

        # Close the progress bar if the loop finished without hitting the limit
        if not pbar.disable:
            pbar.close()

        print(f"\nFinished parsing. Loaded {len(self.instances)} instances from {self.filepath}")


    def __len__(self) -> int:
        """Return the number of loaded instances."""
        return len(self.instances)

    def __getitem__(self, index: int) -> ConcordeTSPInstance:
        """Return the instance at the given index."""
        if 0 <= index < len(self.instances):
            return self.instances[index]
        else:
            raise IndexError("Index out of range")

    def get_iterator(self, shuffle_data: bool = False) -> Iterator[ConcordeTSPInstance]:
        """Return an iterator over the instances (optionally shuffled)."""
        instance_list = self.instances
        if shuffle_data:
            # Shuffle a copy so the reader's original order is preserved
            instance_list = shuffle(list(self.instances))
        return iter(instance_list)

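The line format that `_load_data` expects (flattened `x y` coordinates, the marker `output`, then a 1-based tour that may repeat its start node) can be illustrated on a single line with the standard library only. The function `parse_concorde_line` and the sample line are hypothetical; this is a minimal sketch of the same checks, not a replacement for the reader above.

```python
import math

def parse_concorde_line(line):
    """Parse 'x1 y1 x2 y2 ... output t1 t2 ... [t1]' into (coords, tour, length)."""
    parts = line.split()
    k = parts.index('output')  # raises ValueError if the marker is missing
    if k % 2 != 0:
        raise ValueError("odd number of coordinate values")
    coords = [(float(parts[m]), float(parts[m + 1])) for m in range(0, k, 2)]

    tour = [int(t) - 1 for t in parts[k + 1:]]  # 1-based in the file -> 0-based
    if len(tour) == len(coords) + 1 and tour[0] == tour[-1]:
        tour = tour[:-1]  # drop the repeated start node
    if sorted(tour) != list(range(len(coords))):
        raise ValueError("tour is not a permutation of the nodes")

    # Length of the closed tour under Euclidean distances
    n = len(tour)
    length = sum(math.dist(coords[tour[m]], coords[tour[(m + 1) % n]]) for m in range(n))
    return coords, tour, length

# Made-up sample: the four corners of the unit square, visited in order
sample = "0 0 1 0 1 1 0 1 output 1 2 3 4 1"
coords, tour, length = parse_concorde_line(sample)
print(tour, length)
```

Unlike the reader, which silently skips malformed lines, this sketch raises `ValueError`, which may be more convenient when debugging a dataset file.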
## Benchmark utility class for comparing methods

In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
from pulp import LpVariable, LpProblem, lpSum, LpMinimize, value
import networkx as nx
from scipy.spatial.distance import pdist, squareform

class TSPComparison:
    """
    Utilities for comparing TSP solution methods
    """
    @staticmethod
    def pure_branch_cut_solve(nodes_coord):
        """Pure Branch & Cut TSP solver without GNN guidance.

        Args:
            nodes_coord: Coordinates of nodes (num_nodes, 2)

        Returns:
            tour: Optimal tour as list of node indices
            tour_length: Length of optimal tour
            solve_time: Time taken to solve in seconds
        """
        num_nodes = nodes_coord.shape[0]

        # Get distance matrix
        dist_matrix = squareform(pdist(nodes_coord, metric='euclidean'))

        # Start timing
        start_time = time.time()

        # Create ILP model
        model = LpProblem("TSP_PureBranchCut", LpMinimize)

        # Create edge variables - x[i,j] = 1 if edge (i,j) is used
        x = {}
        for i in range(num_nodes):
            for j in range(i+1, num_nodes):  # Only need one direction for undirected graph
                x[i, j] = LpVariable(f'x_{i}_{j}', cat='Binary')

        # Helper function to access edge variable in either direction
        def get_edge_var(i, j):
            return x[min(i, j), max(i, j)]

        # Objective: minimize total distance
        model += lpSum([dist_matrix[i][j] * get_edge_var(i, j) for i in range(num_nodes) for j in range(i+1, num_nodes)])

        # Constraints: each node must have exactly two edges
        for i in range(num_nodes):
            model += lpSum([get_edge_var(i, j) for j in range(num_nodes) if j != i]) == 2

        # Solve with built-in Branch & Cut
        model.solve()

        # Extract solution
        edges = []
        for i in range(num_nodes):
            for j in range(i+1, num_nodes):
                if value(get_edge_var(i, j)) > 0.5:
                    edges.append((i, j))

        # Build tour from edges
        G = nx.Graph()
        G.add_edges_from(edges)

        # Check for subtours and add constraints until we get a valid tour
        while True:
            # Find connected components (subtours)
            components = list(nx.connected_components(G))

            if len(components) == 1:
                # Valid tour found
                break

            # Add subtour elimination constraints and resolve
            for component in components:
                if len(component) < num_nodes:
                    component = list(component)
                    model += lpSum([get_edge_var(i, j) for i in component for j in component if i < j]) <= len(component) - 1

            # Resolve
            model.solve()

            # Update graph
            G.clear()
            edges = []
            for i in range(num_nodes):
                for j in range(i+1, num_nodes):
                    if value(get_edge_var(i, j)) > 0.5:
                        edges.append((i, j))
            G.add_edges_from(edges)

        # End timing
        solve_time = time.time() - start_time

        # Convert edge list to tour
        tour = []
        current = 0  # Start from node 0
        tour.append(current)

        for _ in range(num_nodes-1):
            neighbors = list(G.neighbors(current))
            # Remove already visited nodes
            unvisited = [n for n in neighbors if n not in tour]
            if not unvisited:
                break
            current = unvisited[0]
            tour.append(current)

        # Calculate tour length
        tour_length = 0
        for i in range(len(tour)):
            j = (i + 1) % len(tour)
            tour_length += dist_matrix[tour[i]][tour[j]]

        return tour, tour_length, solve_time, edges

    @staticmethod
    def pure_gnn_beam_search_solve(nodes_coord: np.ndarray,
                                     gnn_model, config: Dict, beam_size: int,
                                     dtypeFloat, dtypeLong
                                     ) -> Tuple[Optional[List[int]], float, float]:
        """
        Solve the TSP using only the GNN and Beam Search (picking the shortest tour in the beam).

        Args:
            nodes_coord: Node coordinates (num_nodes, 2).
            gnn_model: Trained GNN model (nn.Module).
            config: GNN configuration dictionary.
            beam_size: Beam size for Beam Search.
            dtypeFloat: Type for float tensors.
            dtypeLong: Type for long tensors.

        Returns:
            tour: Tour found (list of 0-based indices), or None on failure.
            tour_length: Length of the tour found (or inf).
            solve_time: Run time in seconds.
        """
        start_time = time.time()
        num_nodes = nodes_coord.shape[0]
        tour = None
        tour_length = float('inf')

        try:
            # 1. Get the GNN predictions (y_pred_edges)
            # Prepare the GNN inputs (same as in get_edge_probabilities)
            dist_matrix = squareform(pdist(nodes_coord, metric='euclidean'))
            x_edges_np = np.ones((1, num_nodes, num_nodes))
            np.fill_diagonal(x_edges_np[0], 2)
            x_edges_values_np = np.expand_dims(dist_matrix, 0)
            x_nodes_np = np.ones((1, num_nodes))
            x_nodes_coord_np = np.expand_dims(nodes_coord, 0)

            # Convert to tensors
            x_edges = torch.LongTensor(x_edges_np).type(dtypeLong)
            x_edges_values = torch.FloatTensor(x_edges_values_np).type(dtypeFloat)
            x_nodes = torch.LongTensor(x_nodes_np).type(dtypeLong)
            x_nodes_coord_tensor = torch.FloatTensor(x_nodes_coord_np).type(dtypeFloat)

            device = x_nodes_coord_tensor.device  # Device the inputs live on
            batch_size = 1

            # Dummy y_edges and edge_cw: the model's forward pass expects them,
            # but the loss computed from them is not used here
            dummy_y_edges = torch.zeros(batch_size, num_nodes, num_nodes, dtype=torch.long, device=device)
            dummy_edge_cw = torch.ones(2, dtype=torch.float, device=device)

            # Move to the GPU if available
            if torch.cuda.is_available():
                x_edges = x_edges.cuda()
                x_edges_values = x_edges_values.cuda()
                x_nodes = x_nodes.cuda()
                x_nodes_coord_tensor = x_nodes_coord_tensor.cuda()

            # Prediction
            gnn_model.eval()
            with torch.no_grad():
                # y_edges and edge_cw are dummies; only the edge predictions are used
                result = gnn_model(x_edges, x_edges_values, x_nodes, x_nodes_coord_tensor,
                                   dummy_y_edges, dummy_edge_cw)
                if isinstance(result, tuple):
                    y_pred_edges = result[0]
                else:
                    y_pred_edges = result

            # 2. Run Beam Search (keeping the shortest tour in the beam)
            # beamsearch_tour_nodes_shortest expects y_pred_edges and x_edges_values
            predicted_tours_tensor = beamsearch_tour_nodes_shortest(  # or beamsearch_tour_nodes
                y_pred_edges,           # GNN predictions
                x_edges_values,         # Distance matrix (tensor)
                beam_size=beam_size, batch_size=batch_size, num_nodes=num_nodes,
                dtypeFloat=dtypeFloat, dtypeLong=dtypeLong,
                probs_type='raw', random_start=False
            )

            # 3. Take the best tour from the batch (there is only one)
            if predicted_tours_tensor is not None and predicted_tours_tensor.shape[0] > 0:
                tour_tensor = predicted_tours_tensor[0]  # First (and only) tour
                tour = tour_tensor.cpu().numpy().tolist()  # Convert to a Python list
                # Check that the tour is valid
                if is_valid_tour(tour, num_nodes):
                    # Compute the length of the tour found
                    tour_length = ConcordeTSPInstance.calculate_tour_len(tour, dist_matrix)
                else:
                    print("Warning: Beam Search produced an invalid tour.")
                    tour = None  # Discard the invalid tour
                    tour_length = float('inf')
            else:
                print("Warning: Beam Search did not return a tour.")

        except Exception as e:
            print(f"Error during Pure GNN + Beam Search solve: {e}")
            tour = None
            tour_length = float('inf')

        solve_time = time.time() - start_time
        return tour, tour_length, solve_time

    @staticmethod
    def visualize_tours(nodes_coord, tour_pure, tour_gnn, title="TSP Tours Comparison"):
        """Visualize and compare two TSP tours.

        Args:
            nodes_coord: Node coordinates (num_nodes, 2)
            tour_pure: Tour from pure Branch & Cut
            tour_gnn: Tour from GNN-guided Branch & Cut
            title: Plot title
        """
        plt.figure(figsize=(12, 6))

        # Plot pure Branch & Cut tour
        plt.subplot(1, 2, 1)
        plt.scatter(nodes_coord[:, 0], nodes_coord[:, 1], c='blue', s=50)

        # Connect tour points
        for i in range(len(tour_pure)):
            j = (i + 1) % len(tour_pure)
            plt.plot([nodes_coord[tour_pure[i], 0], nodes_coord[tour_pure[j], 0]],
                     [nodes_coord[tour_pure[i], 1], nodes_coord[tour_pure[j], 1]], 'r-')

        # Number the nodes
        for i, (x, y) in enumerate(nodes_coord):
            plt.text(x, y, str(i), fontsize=12)

        plt.title("Pure Branch & Cut Tour")

        # Plot GNN-guided Branch & Cut tour
        plt.subplot(1, 2, 2)
        plt.scatter(nodes_coord[:, 0], nodes_coord[:, 1], c='blue', s=50)

        # Connect tour points
        for i in range(len(tour_gnn)):
            j = (i + 1) % len(tour_gnn)
            plt.plot([nodes_coord[tour_gnn[i], 0], nodes_coord[tour_gnn[j], 0]],
                     [nodes_coord[tour_gnn[i], 1], nodes_coord[tour_gnn[j], 1]], 'r-')

        # Number the nodes
        for i, (x, y) in enumerate(nodes_coord):
            plt.text(x, y, str(i), fontsize=12)

        plt.title("GNN-guided Branch & Cut Tour")

        plt.suptitle(title)
        plt.tight_layout()
        plt.show()

    @staticmethod
    def compare_methods(
                        # GNN-guided B&C solvers and GNN models
                        gnn_solver_v1,
                        gnn_solver_v2,
                        gnn_model_v1: Optional[nn.Module],
                        gnn_model_v2: Optional[nn.Module],
                        config: Dict, # Shared GNN config
                        test_instances: List[np.ndarray],
                        # Parameter for Pure GNN+BS
                        beam_size_gnn_pure: int = 10,
                        metrics=['tour_length', 'solving_time']
                        ) -> Dict[str, Dict]:
        """
        Compare 6 methods:
        1. Pure B&C
        2. GNN+BS (v1)
        3. GNN+BS (v2)
        4. GNN+B&C (v1)
        5. GNN+B&C (v2)
        6. Simulated Annealing + 2-opt (outer)

        Args:
            gnn_solver_v1: GNN-guided B&C solver built on model v1, or None to skip it.
            gnn_solver_v2: GNN-guided B&C solver built on model v2, or None to skip it.
            gnn_model_v1: Trained GNN model v1 (Transformer), or None to skip it.
            gnn_model_v2: Trained GNN model v2 (Linear), or None to skip it.
            config: GNN configuration dictionary.
            test_instances: List of node-coordinate arrays.
            beam_size_gnn_pure: Beam size for Pure GNN + Beam Search.
            metrics: Metrics to compare (currently unused).

        Returns:
            Dictionary with per-instance and averaged results for the 6 methods.
        """
        results = {
            'pure_bc':         {'tour_lengths': [], 'solving_times': [], 'tour': []},
            'gnn_bs_v1':       {'tour_lengths': [], 'solving_times': [], 'tour': []}, # Model 1 + BS
            'gnn_bs_v2':       {'tour_lengths': [], 'solving_times': [], 'tour': []}, # Model 2 + BS
            'gnn_bc_v1':       {'tour_lengths': [], 'solving_times': [], 'tour': []}, # Model 1 + B&C
            'gnn_bc_v2':       {'tour_lengths': [], 'solving_times': [], 'tour': []}, # Model 2 + B&C
            'simulated_annealing_2opt':         {'tour_lengths': [], 'solving_times': [], 'tour': []},
        }

        # Tensor types
        dtypeFloat = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
        dtypeLong = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor

        for i, nodes_coord in enumerate(test_instances):
            print(f"\n===== Solving instance {i+1}/{len(test_instances)} =====")

            # --- 1. Pure Branch & Cut ---
            print("  Running Pure B&C...")
            tour_pure, tour_length_pure, pure_time, _ = TSPComparison.pure_branch_cut_solve(nodes_coord)
            results['pure_bc']['tour_lengths'].append(tour_length_pure)
            results['pure_bc']['solving_times'].append(pure_time)
            results['pure_bc']['tour'].append(tour_pure)
            print(f"    Pure B&C: Length={tour_length_pure:.4f}, Time={pure_time:.4f}s")

            # --- 2. Pure GNN + Beam Search (Model 1) ---
            if gnn_model_v1:
                print("  Running Pure GNN + Beam Search (Model 1 - Transformer)...")
                tour_gnn_bs_v1, tour_length_gnn_bs_v1, gnn_bs_time_v1 = TSPComparison.pure_gnn_beam_search_solve(
                    nodes_coord, gnn_model_v1, config, beam_size_gnn_pure, dtypeFloat, dtypeLong
                )
                results['gnn_bs_v1']['tour_lengths'].append(tour_length_gnn_bs_v1)
                results['gnn_bs_v1']['solving_times'].append(gnn_bs_time_v1)
                results['gnn_bs_v1']['tour'].append(tour_gnn_bs_v1)
                print(f"    GNN+BS(v1): Length={tour_length_gnn_bs_v1:.4f}, Time={gnn_bs_time_v1:.4f}s")
            else:
                print("  Skipping Pure GNN + BS (Model 1 not loaded).")
                results['gnn_bs_v1']['tour_lengths'].append(float('inf'))
                results['gnn_bs_v1']['solving_times'].append(float('inf'))
                results['gnn_bs_v1']['tour'].append(None)

            # --- 3. Pure GNN + Beam Search (Model 2) ---
            if gnn_model_v2:
                print("  Running Pure GNN + Beam Search (Model 2 - Linear)...")
                tour_gnn_bs_v2, tour_length_gnn_bs_v2, gnn_bs_time_v2 = TSPComparison.pure_gnn_beam_search_solve(
                    nodes_coord, gnn_model_v2, config, beam_size_gnn_pure, dtypeFloat, dtypeLong
                )
                results['gnn_bs_v2']['tour_lengths'].append(tour_length_gnn_bs_v2)
                results['gnn_bs_v2']['solving_times'].append(gnn_bs_time_v2)
                results['gnn_bs_v2']['tour'].append(tour_gnn_bs_v2)
                print(f"    GNN+BS(v2): Length={tour_length_gnn_bs_v2:.4f}, Time={gnn_bs_time_v2:.4f}s")
            else:
                print("  Skipping Pure GNN + BS (Model 2 not loaded).")
                results['gnn_bs_v2']['tour_lengths'].append(float('inf'))
                results['gnn_bs_v2']['solving_times'].append(float('inf'))
                results['gnn_bs_v2']['tour'].append(None)

            # --- 4. GNN-guided Branch & Cut (Model 1) ---
            if gnn_solver_v1:
                print("  Running GNN-guided B&C (Model 1 - Transformer)...")
                tour_gnn_bc_v1, tour_length_gnn_bc_v1, gnn_bc_time_v1 = gnn_solver_v1.solve_tsp(nodes_coord)
                results['gnn_bc_v1']['tour_lengths'].append(tour_length_gnn_bc_v1)
                results['gnn_bc_v1']['solving_times'].append(gnn_bc_time_v1)
                results['gnn_bc_v1']['tour'].append(tour_gnn_bc_v1)
                print(f"    GNN+B&C(v1): Length={tour_length_gnn_bc_v1:.4f}, Time={gnn_bc_time_v1:.4f}s")
            else:
                print("  Skipping GNN-guided B&C (Model 1 not loaded or solver init failed).")
                results['gnn_bc_v1']['tour_lengths'].append(float('inf'))
                results['gnn_bc_v1']['solving_times'].append(float('inf'))
                results['gnn_bc_v1']['tour'].append(None)


            # --- 5. GNN-guided Branch & Cut (Model 2) ---
            if gnn_solver_v2:
                print("  Running GNN-guided B&C (Model 2 - Linear)...")
                tour_gnn_bc_v2, tour_length_gnn_bc_v2, gnn_bc_time_v2 = gnn_solver_v2.solve_tsp(nodes_coord)
                results['gnn_bc_v2']['tour_lengths'].append(tour_length_gnn_bc_v2)
                results['gnn_bc_v2']['solving_times'].append(gnn_bc_time_v2)
                results['gnn_bc_v2']['tour'].append(tour_gnn_bc_v2)
                print(f"    GNN+B&C(v2): Length={tour_length_gnn_bc_v2:.4f}, Time={gnn_bc_time_v2:.4f}s")
            else:
                print("  Skipping GNN-guided B&C (Model 2 not loaded or solver init failed).")
                results['gnn_bc_v2']['tour_lengths'].append(float('inf'))
                results['gnn_bc_v2']['solving_times'].append(float('inf'))
                results['gnn_bc_v2']['tour'].append(None)
            
            # --- 6. Simulated Annealing + 2-opt (outer) ---
            print("  Running Simulated Annealing 2-opt (outer)...")
            distance_matrix = squareform(pdist(nodes_coord, metric='euclidean'))
            simulated_annealing_2opt = TSP_2opt_SA_outer(distance_matrix)
            
            start = time.time()
            tour_sa_2opt, tour_length_sa_2opt = simulated_annealing_2opt.solve()
            time_sa_2opt = time.time() - start
            results['simulated_annealing_2opt']['tour_lengths'].append(tour_length_sa_2opt)
            results['simulated_annealing_2opt']['solving_times'].append(time_sa_2opt)
            results['simulated_annealing_2opt']['tour'].append(tour_sa_2opt)
            print(f"    Simulated Annealing 2-opt (outer): Length={tour_length_sa_2opt:.4f}, Time={time_sa_2opt:.4f}s")

        # Compute averages
        print("\nCalculating average results...")
        for method in list(results.keys()):
            # Averages default to inf
            results[method]['avg_tour_lengths'] = float('inf')
            results[method]['avg_solving_times'] = float('inf')

            metrics_to_average = ['tour_lengths', 'solving_times']
            for metric in metrics_to_average:
                # Skip metrics that are missing or empty
                if results[method].get(metric):
                    # Filter out None and inf values
                    valid_values = [v for v in results[method][metric] if v is not None and v != float('inf')]
                    if valid_values:  # At least one valid value
                        results[method][f'avg_{metric}'] = np.mean(valid_values)
                    # Otherwise the inf default is kept

        return results

    @staticmethod
    def plot_comparison_results(results: Dict[str, Dict[str, Any]]):
        """
        Plot the average tour length and solving time of each method.

        Args:
            results: Dictionary returned by compare_methods, keyed by method
                     ('pure_bc', 'gnn_bs_v1', etc.), with sub-keys
                     'avg_tour_lengths' and 'avg_solving_times'.
        """
        # --- Extract the averages ---
        method_names = ['Pure B&C', 'GNN+BS(v1)', 'GNN+BS(v2)', 'GNN+B&C(v1)', 'GNN+B&C(v2)', "Simulated Annealing + 2-opt (outer)"]
        method_keys = ['pure_bc', 'gnn_bs_v1', 'gnn_bs_v2', 'gnn_bc_v1', 'gnn_bc_v2', "simulated_annealing_2opt"]

        avg_lengths = []
        avg_times = []

        for key in method_keys:
            avg_len = results.get(key, {}).get('avg_tour_lengths', float('inf'))
            avg_time = results.get(key, {}).get('avg_solving_times', float('inf'))

            # Replace inf with NaN so that methods without results are left without a bar
            avg_lengths.append(avg_len if avg_len != float('inf') else np.nan)
            avg_times.append(avg_time if avg_time != float('inf') else np.nan)

        # --- Plotting ---
        plt.figure(figsize=(16, 7))

        # Tour length plot (one color per method)
        plt.subplot(1, 2, 1)
        bars1 = plt.bar(method_names, avg_lengths, color=['blue', 'orange', 'green', 'red', 'purple', 'brown'])
        plt.title('Average Tour Length Comparison')
        plt.ylabel('Average Length')
        plt.xticks(rotation=25, ha='right')  # Rotate labels for readability
        # Add the value above each bar
        for bar in bars1:
            yval = bar.get_height()
            if not np.isnan(yval):
                plt.text(bar.get_x() + bar.get_width()/2.0, yval, f'{yval:.2f}', va='bottom', ha='center')

        # Solving time plot (one color per method)
        plt.subplot(1, 2, 2)
        bars2 = plt.bar(method_names, avg_times, color=['blue', 'orange', 'green', 'red', 'purple', 'brown'])
        plt.title('Average Solving Time Comparison')
        plt.ylabel('Average Time (s)')
        plt.xticks(rotation=25, ha='right')
        # Add the value above each bar
        for bar in bars2:
            yval = bar.get_height()
            if not np.isnan(yval):
                plt.text(bar.get_x() + bar.get_width()/2.0, yval, f'{yval:.3f}', va='bottom', ha='center')  # More digits for time

        plt.tight_layout()
        plt.show()
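
`pure_branch_cut_solve` turns the selected edges into a visiting order by walking the graph from node 0, then sums the distances around the closed tour. A minimal dependency-free sketch of that step follows, assuming the edges already form a single Hamiltonian cycle. `edges_to_tour` and `tour_length` are hypothetical helper names, not part of the notebook.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def edges_to_tour(edges: Sequence[Tuple[int, int]], start: int = 0) -> List[int]:
    """Order the nodes of a single cycle, given as undirected edges, starting at `start`."""
    adjacency: Dict[int, List[int]] = defaultdict(list)
    for i, j in edges:
        adjacency[i].append(j)
        adjacency[j].append(i)

    tour = [start]
    visited = {start}
    current = start
    while True:
        # Step to a neighbour that has not been visited yet
        unvisited = [n for n in adjacency[current] if n not in visited]
        if not unvisited:
            break  # Every neighbour is visited: the cycle is closed
        current = unvisited[0]
        tour.append(current)
        visited.add(current)
    return tour


def tour_length(tour: Sequence[int], coords: Sequence[Tuple[float, float]]) -> float:
    """Length of the closed tour, including the edge back to the first node."""
    total = 0.0
    for k in range(len(tour)):
        a = coords[tour[k]]
        b = coords[tour[(k + 1) % len(tour)]]
        total += math.dist(a, b)
    return total


# Unit square; the edges are deliberately listed out of order
square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
tour = edges_to_tour([(2, 3), (0, 1), (0, 3), (1, 2)])
print(tour, tour_length(tour, square))
```

The `visited` set makes each membership check constant-time, whereas the `n not in tour` list check in the solver is linear per step. For instances with a few hundred nodes the difference is likely small next to the ILP solve itself.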

## Usage example

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import time
import random
import pandas as pd
import seaborn as sns
import textwrap  # For wrapping long method names
import csv
import os
import json

# Initialize and train the GNN model
def train_gnn_model(config, num_epochs=25, version=1):
    """Train GNN model for TSP with validation and testing.

    Args:
        config: Configuration dictionary
        num_epochs: Number of epochs to train
        version: Model variant: 1 for ResidualGatedGCNModel, otherwise ResidualGatedGCNModel_v2

    Returns:
        net: Trained GNN model
        losses: Dictionary containing training, validation and test losses
    """
    # Initialize model
    if version == 1:
        net = nn.DataParallel(ResidualGatedGCNModel(config, torch.cuda.FloatTensor, torch.cuda.LongTensor))
    else:
        net = nn.DataParallel(ResidualGatedGCNModel_v2(config, torch.cuda.FloatTensor, torch.cuda.LongTensor))
    net.cuda()

    # Setup optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=config['learning_rate'], weight_decay=1e-5)
    val_loss_old = None

    # Prepare for tracking losses
    train_losses = []
    val_losses = []
    test_losses = []

    print("Starting training...")
    for epoch in range(num_epochs):
        # Train one epoch
        train_time, train_loss = train_one_epoch(net, optimizer, config)
        train_losses.append(train_loss)

        print(f"Epoch: {epoch}, Train Loss: {train_loss:.4f}, Time: {train_time:.2f}s")

        # Validation phase
        if epoch % config["val_every"] == 0 or epoch == num_epochs-1:
            val_time, val_loss = test(net, config, mode='val')
            val_losses.append(val_loss)
            print(f"Epoch: {epoch}, Val Loss: {val_loss:.4f}, Time: {val_time:.2f}s")

            # Update learning rate based on validation performance
            if val_loss_old is not None and val_loss > 0.99 * val_loss_old:
                config["learning_rate"] /= config["decay_rate"]
                optimizer = update_learning_rate(optimizer, config["learning_rate"])
                print(f"Learning rate updated to: {config['learning_rate']:.6f}")

            val_loss_old = val_loss  # Update old validation loss

        # Testing phase
        if epoch % config["test_every"] == 0 or epoch == num_epochs-1:
            test_time, test_loss = test(net, config, mode='test')
            test_losses.append(test_loss)
            print(f"Epoch: {epoch}, Test Loss: {test_loss:.4f}, Time: {test_time:.2f}s\n")
        
        if epoch % config["save_every"] == 0 or epoch == num_epochs-1:
            _, test_loss_save = test(net, config, mode='test')
            save_model(net, f"./tsp_gnn_model_epoch_{epoch}_test_loss_{test_loss_save}.pt")
            print(f"Epoch: {epoch}, Test Loss: {test_loss_save:.4f}, Model Saved...") 

    print("Training complete!")

    # Return model and losses for further analysis
    losses = {
        'train': train_losses,
        'validation': val_losses,
        'test': test_losses
    }

    return net, losses

# Generate test instances
def generate_test_instances(num_instances=10, num_nodes=20):
    """Generate random TSP test instances.

    Args:
        num_instances: Number of instances to generate
        num_nodes: Number of nodes per instance

    Returns:
        instances: List of node coordinate arrays
    """
    instances = []
    for _ in range(num_instances):
        # Generate random node coordinates in [0, 1] x [0, 1]
        nodes_coord = np.random.rand(num_nodes, 2)
        instances.append(nodes_coord)
    return instances

# --- Constants and configuration ---
NUM_TEST_INSTANCES = 1 # Reduced to speed up debugging
NUM_NODES = 500
NUM_RUNS = 1
BEAM_SIZE = 5 # Beam size for Beam Search #1, 10

GNN_BC_PARAMS = {
    "threshold": 0.4, # 0.0, 0.3, 0.5, 0.0, 0.5, 0.9
    "fixing_percentage": 0.0,
    "elimination_percentage": 0.60 # 0.45, 0.45, 0.45, 0.7, 0.7, 0.7
}

RESULTS_FILENAME = f"hyperparameter_tuning_results_bs{BEAM_SIZE}_ep{GNN_BC_PARAMS['elimination_percentage']}_th{GNN_BC_PARAMS['threshold']}.csv"
FILE_EXISTS = os.path.exists(RESULTS_FILENAME)

# Open the file in append mode ('a'); the header is written below only if the file is new
csv_file = open(RESULTS_FILENAME, 'a', newline='', encoding='utf-8')
csv_writer = csv.writer(csv_file)

if not FILE_EXISTS:
    header = [
        "Beam Size", "Elimination %", "Threshold", # GNN parameters
        "Instance", "Dimension",                   # Instance parameters
        "Method",                                  # Method name
        "Avg Length", "Avg Time (s)",              # Metrics
        "Length Diff", "Speedup vs Pure"           # Comparison with Pure B&C
    ]
    csv_writer.writerow(header)


# Model paths
MODEL_PATH_V1 = "./tsp_gnn_model_naive.pt" # Model with TransformerConv
MODEL_PATH_V2 = "./tsp_gnn_model_epoch_499_test_loss_0.12400513887405396.pt" # Model without TransformerConv

# Initialize configuration
config = {
    'train_filepath': './tsp_data/tsp20_train_concorde.txt',
    'val_filepath': './tsp_data/tsp20_val_concorde.txt',
    'test_filepath': './tsp_data/tsp20_test_concorde.txt',
    'num_nodes': 20,
    'num_neighbors': -1,
    'node_dim': 2,
    'voc_nodes_in': 2,
    'voc_nodes_out': 2,
    'voc_edges_in': 3,
    'voc_edges_out': 2,
    'hidden_dim': 300,
    'num_layers': 5,
    'mlp_layers': 2,
    'aggregation': 'mean',
    'max_epochs': 50,
    'batches_per_epoch': 256,  # Reduced for example
    'accumulation_steps': 1,
    'learning_rate': 0.001,
    'decay_rate': 1.01,
    'val_every': 3,
    'test_every': 3,
    'save_every': 50,
    'batch_size': 64

}

# Tensor types
dtypeFloat = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
dtypeLong = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor

# --- Load models ---
print("Loading GNN Models...")
net_v1 = None
net_v2 = None

if os.path.exists(MODEL_PATH_V1):
    net_v1 = load_model(ResidualGatedGCNModel, config, MODEL_PATH_V1, dtypeFloat, dtypeLong)
    print(f"Model 1 ({MODEL_PATH_V1}) loaded successfully.")
else:
    print(f"Model file not found: {MODEL_PATH_V1}")
    # Fallback: train V1 if the file is not found
    print("Training Model V1...")
    net_v1, losses_v1 = train_gnn_model(config, num_epochs=50)
    save_model(net_v1, MODEL_PATH_V1)

if os.path.exists(MODEL_PATH_V2):
    net_v2 = load_model(ResidualGatedGCNModel_v2, config, MODEL_PATH_V2, dtypeFloat, dtypeLong)
    print(f"Model 2 ({MODEL_PATH_V2}) loaded successfully.")
else:
    print(f"Model file not found: {MODEL_PATH_V2}")
    # Fallback: train V2 if the file is not found
    print("Training Model V2...")
    net_v2, losses_v2 = train_gnn_model(config, num_epochs=500, version=2)
    save_model(net_v2, MODEL_PATH_V2)



# --- Generate test instances ---
test_instances = generate_test_instances(num_instances=NUM_TEST_INSTANCES, num_nodes=NUM_NODES)

# --- Run the comparison ---
print("\nRunning Comparison: Pure B&C vs GNN+BS(v1) vs GNN+BS(v2) vs GNN+B&C(v1) vs GNN+B&C(v2)...")
current_gnn_bc_params_v1 = GNN_BC_PARAMS.copy()
current_gnn_bc_params_v2 = GNN_BC_PARAMS.copy()
# current_gnn_bc_params_v1["elimination_percentage"] = elimination_perc
# current_gnn_bc_params_v1["threshold"] = threshold
# current_gnn_bc_params_v2["elimination_percentage"] = elimination_perc
# current_gnn_bc_params_v2["threshold"] = threshold

print(f"\n--- Testing Params: Beam={BEAM_SIZE}, Elim={current_gnn_bc_params_v1['elimination_percentage']:.1f}, Thresh={current_gnn_bc_params_v1['threshold']:.1f} ---")

# --- Initialize solvers ---
gnn_solver_v1_bc = None
if net_v1:
    gnn_solver_v1_bc = GNNBranchCutSolver(gnn_model=net_v1, **current_gnn_bc_params_v1)

gnn_solver_v2_bc = None
if net_v2:
    gnn_solver_v2_bc = GNNBranchCutSolver(gnn_model=net_v2, **current_gnn_bc_params_v2)

for run_num in range(NUM_RUNS):
    all_results = {}

    # Models/solvers that failed to load are passed as None
    all_results = TSPComparison.compare_methods(
        # Current solvers
        gnn_solver_v1=gnn_solver_v1_bc,
        gnn_solver_v2=gnn_solver_v2_bc,
        gnn_model_v1=net_v1,
        gnn_model_v2=net_v2,
        config=config,
        test_instances=test_instances, # The SAME set of instances in every run
        # Current beam size
        beam_size_gnn_pure=BEAM_SIZE
    )

    # --- Log this iteration's results to CSV ---
    # if all_results:
    #     # Get the baseline Pure B&C results (assumed not to depend on the GNN parameters)
    #     pure_res = all_results.get('pure_bc', {})
    #     pure_avg_len = pure_res.get('avg_tour_lengths', float('inf'))
    #     pure_avg_time = pure_res.get('avg_solving_times', float('inf'))

    #     methods_to_log = {
    #         "GNN+BS (v1)": 'gnn_bs_v1', "GNN+BS (v2)": 'gnn_bs_v2',
    #         "GNN+B&C (v1)": 'gnn_bc_v1', "GNN+B&C (v2)": 'gnn_bc_v2'
    #     }

    #     # Get the dimension (assumed to be the same for all instances in test_instances)
    #     dimension = NUM_NODES # Taken from the config in case test_instances differ in size

    #     for name, key in methods_to_log.items():
    #         if key in all_results and all_results[key]:
    #             res = all_results[key]
    #             avg_len = res.get('avg_tour_lengths', float('inf'))
    #             avg_time = res.get('avg_solving_times', float('inf'))

    #             len_diff = (avg_len - pure_avg_len) if avg_len != float('inf') and pure_avg_len != float('inf') else float('inf')
    #             speedup = (pure_avg_time / avg_time) if avg_time > 0 and pure_avg_time != float('inf') else float('inf')

    #             # Write the row to the CSV
    #             csv_writer.writerow([
    #                 BEAM_SIZE, current_gnn_bc_params_v1["elimination_percentage"], current_gnn_bc_params_v1["threshold"],
    #                 f"Avg_{len(test_instances)}x{dimension}", dimension, # The instance name is now generic
    #                 name,
    #                 avg_len, avg_time,
    #                 len_diff, speedup
    #             ])

    # csv_file.close() # Make sure to close the file
    # print(f"\nHyperparameter tuning results saved to {RESULTS_FILENAME}")

    print("\n===== Overall Method Comparison =====")
    print("Method               | Avg Length     | Length Diff | Avg Time (s) | Speedup vs Pure")
    print("---------------------|----------------|-------------|--------------|-----------------")

    if all_results:
        pure_avg_len = all_results['pure_bc'].get('avg_tour_lengths', float('inf'))
        pure_avg_time = all_results['pure_bc'].get('avg_solving_times', float('inf'))
        if pure_avg_len != float('inf') and pure_avg_time != float('inf'):
            print(f"Pure B&C           | {pure_avg_len:<14.4f} | +0.0000     | {pure_avg_time:<12.4f} | 1.00x")
        else:
            print("Pure B&C results could not be calculated.")

        # Output order and result keys
        methods_to_print = {
            "GNN+BS (v1)": 'gnn_bs_v1',
            "GNN+BS (v2)": 'gnn_bs_v2',
            "GNN+B&C (v1)": 'gnn_bc_v1',
            "GNN+B&C (v2)": 'gnn_bc_v2',
            "Simulated Annealing + 2-opt (outer)": "simulated_annealing_2opt"
        }

        for name, key in methods_to_print.items():
            if key in all_results and all_results[key]: # Check that the key exists and has data
                res = all_results[key]
                avg_len = res.get('avg_tour_lengths', float('inf'))
                avg_time = res.get('avg_solving_times', float('inf'))
                speedup = (pure_avg_time / avg_time) if avg_time > 0 and pure_avg_time != float('inf') else float('inf')
                len_diff = (avg_len - pure_avg_len) if avg_len != float('inf') and pure_avg_len != float('inf') else float('inf')
                print(f"{name:<20} | {avg_len:<14.4f} | {len_diff:<+11.4f} | {avg_time:<12.4f} | {speedup:.2f}x")
            else:
                print(f"{name:<20} | N/A            | N/A         | N/A          | N/A")

    print("-------------------------------------------------------------------------------")

    # --- Visualization (example: comparing Pure B&C and GNN+BS (v1)) ---
    method_to_compare_1 = 'pure_bc'
    method_to_compare_2 = 'gnn_bs_v1'
    method_to_compare_3 = 'gnn_bs_v2'
    method_to_compare_4 = 'gnn_bc_v1'
    method_to_compare_5 = 'gnn_bc_v2'
    method_to_compare_6 = 'simulated_annealing_2opt'

    # for instance_idx_viz in range(5):
    #     if instance_idx_viz < len(test_instances) and method_to_compare_1 in all_results and method_to_compare_2 in all_results:
    #         coords_viz = test_instances[instance_idx_viz]
    #         tour_pure_viz = all_results['pure_bc']['tour'][instance_idx_viz]
    #         len_pure_viz = all_results['pure_bc']['tour_lengths'][instance_idx_viz]
    #         tour_gnn_viz = all_results['gnn_bs_v1']['tour'][instance_idx_viz]
    #         len_gnn_viz = all_results['gnn_bs_v1']['tour_lengths'][instance_idx_viz]

    #         if tour_pure_viz and tour_gnn_viz:
    #             plot_comparison(
    #                 nodes_coord=coords_viz,
    #                 found_tour=tour_gnn_viz,
    #                 found_length=len_gnn_viz,
    #                 optimal_tour=tour_pure_viz,
    #                 optimal_length=len_pure_viz,
    #                 instance_name=f"Instance {instance_idx_viz+1}",
    #                 method_name="GNN+BS (v1)"
    #             )


    # for instance_idx_viz in range(5):
    #     if instance_idx_viz < len(test_instances) and method_to_compare_1 in all_results and method_to_compare_2 in all_results:
    #         coords_viz = test_instances[instance_idx_viz]
    #         tour_pure_viz = all_results['pure_bc']['tour'][instance_idx_viz]
    #         len_pure_viz = all_results['pure_bc']['tour_lengths'][instance_idx_viz]
    #         tour_gnn_viz = all_results['gnn_bc_v1']['tour'][instance_idx_viz]
    #         len_gnn_viz = all_results['gnn_bc_v1']['tour_lengths'][instance_idx_viz]

    #         if tour_pure_viz and tour_gnn_viz:
    #             plot_comparison(
    #                 nodes_coord=coords_viz,
    #                 found_tour=tour_gnn_viz,
    #                 found_length=len_gnn_viz,
    #                 optimal_tour=tour_pure_viz,
    #                 optimal_length=len_pure_viz,
    #                 instance_name=f"Instance {instance_idx_viz+1}",
    #                 method_name="GNN+BC (v1)"
    #             )

    # for instance_idx_viz in range(5):
    #     if instance_idx_viz < len(test_instances) and method_to_compare_1 in all_results and method_to_compare_2 in all_results:
    #         coords_viz = test_instances[instance_idx_viz]
    #         tour_pure_viz = all_results['pure_bc']['tour'][instance_idx_viz]
    #         len_pure_viz = all_results['pure_bc']['tour_lengths'][instance_idx_viz]
    #         tour_gnn_viz = all_results['gnn_bc_v2']['tour'][instance_idx_viz]
    #         len_gnn_viz = all_results['gnn_bc_v2']['tour_lengths'][instance_idx_viz]

    #         if tour_pure_viz and tour_gnn_viz:
    #             plot_comparison(
    #                 nodes_coord=coords_viz,
    #                 found_tour=tour_gnn_viz,
    #                 found_length=len_gnn_viz,
    #                 optimal_tour=tour_pure_viz,
    #                 optimal_length=len_pure_viz,
    #                 instance_name=f"Instance {instance_idx_viz+1}",
    #                 method_name="GNN+BC (v2)"
    #             )

    # 5. Print results
    print("\nResults Summary:")
    print(f"Pure Branch & Cut - Avg Tour Length: {all_results['pure_bc']['avg_tour_lengths']:.4f}, Avg Time: {all_results['pure_bc'].get('avg_solving_times', float('inf')):.4f}s")
    print(f"GNN Branch & Cut (v1) - Avg Tour Length: {all_results['gnn_bc_v1']['avg_tour_lengths']:.4f}, Avg Time: {all_results['gnn_bc_v1'].get('avg_solving_times', float('inf')):.4f}s")
    print(f"GNN Branch & Cut (v2) - Avg Tour Length: {all_results['gnn_bc_v2']['avg_tour_lengths']:.4f}, Avg Time: {all_results['gnn_bc_v2'].get('avg_solving_times', float('inf')):.4f}s")
    print(f"GNN+BS (v1) - Avg Tour Length: {all_results['gnn_bs_v1']['avg_tour_lengths']:.4f}, Avg Time: {all_results['gnn_bs_v1'].get('avg_solving_times', float('inf')):.4f}s")
    print(f"GNN+BS (v2) - Avg Tour Length: {all_results['gnn_bs_v2']['avg_tour_lengths']:.4f}, Avg Time: {all_results['gnn_bs_v2'].get('avg_solving_times', float('inf')):.4f}s")
    print(f"Simulated Annealing + 2-opt (outer) - Avg Tour Length: {all_results['simulated_annealing_2opt']['avg_tour_lengths']:.4f}, Avg Time: {all_results['simulated_annealing_2opt'].get('avg_solving_times', float('inf')):.4f}s")
    
    # Calculate improvements
    #time_improvement_v1 = (all_results['pure_bc'].get('avg_solving_times', float('inf')) - all_results['gnn_bc_v1'].get('avg_solving_times', float('inf'))) / all_results['pure_bc'].get('avg_solving_times', float('inf')) * 100
    #time_improvement_v2 = (all_results['pure_bc'].get('avg_solving_times', float('inf')) - all_results['gnn_bc_v2'].get('avg_solving_times', float('inf'))) / all_results['pure_bc'].get('avg_solving_times', float('inf')) * 100
    #quality_diff_v1 = (all_results['pure_bc']['avg_tour_lengths'] - all_results['gnn_bc_v1']['avg_tour_lengths']) / all_results['pure_bc']['avg_tour_lengths'] * 100
    #quality_diff_v2 = (all_results['pure_bc']['avg_tour_lengths'] - all_results['gnn_bc_v2']['avg_tour_lengths']) / all_results['pure_bc']['avg_tour_lengths'] * 100

    #print(f"\nTime improvement (bc v1 vs pure bc): {time_improvement_v1:.2f}%")
    #print(f"Solution quality difference (bc v1 vs pure bc): {quality_diff_v1:.2f}% ({'better' if quality_diff_v1 >= 0 else 'worse'})")

    # print(f"\nTime improvement (bc v2 vs pure bc): {time_improvement_v2:.2f}%")
    # print(f"Solution quality difference (bc v2 vs pure bc): {quality_diff_v2:.2f}% ({'better' if quality_diff_v2 >= 0 else 'worse'})")

    # 6. Visualize comparison
    print("Results:\n", all_results)

    aggregated_results = {}
    method_keys_to_aggregate = ['pure_bc', 'gnn_bs_v1', 'gnn_bs_v2', 'gnn_bc_v1', 'gnn_bc_v2', 'simulated_annealing_2opt']

    for method_key in method_keys_to_aggregate:
        if method_key in all_results:
            # Copy the averages that were already computed
            aggregated_results[method_key] = {
                'avg_tour_lengths': all_results[method_key].get('avg_tour_lengths', float('inf')),
                'avg_solving_times': all_results[method_key].get('avg_solving_times', float('inf'))
            }
        else:
            # If the method was not tested, store inf as a placeholder for N/A
            aggregated_results[method_key] = {
                'avg_tour_lengths': float('inf'),
                'avg_solving_times': float('inf')
            }

    # TSPComparison is defined at module level, so look it up in globals();
    # the plot is drawn once, only when the method is available.
    if 'TSPComparison' in globals() and hasattr(TSPComparison, 'plot_comparison_results'):
        TSPComparison.plot_comparison_results(aggregated_results)
    else:
        print("Warning: TSPComparison.plot_comparison_results not found or TSPComparison not defined.")

    # 7. Detailed visualization of one instance
    instance_idx = 0  # Choose first instance
    print(f"\nVisualizing detailed comparison for test instance {instance_idx + 1}...")

    # # Solve with pure Branch & Cut
    # tour_pure, tour_length_pure, _, _ = TSPComparison.pure_branch_cut_solve(test_instances[instance_idx])

    # # Solve with GNN-guided Branch & Cut (v1)
    # tour_gnn, tour_length_gnn, _ = gnn_solver_v1_bc.solve_tsp(test_instances[instance_idx])

    # # Visualize
    # TSPComparison.visualize_tours(
    #     test_instances[instance_idx],
    #     tour_pure,
    #     tour_gnn,
    #     f"TSP Instance Comparison - Pure: {tour_length_pure:.4f}, GNN: {tour_length_gnn:.4f}"
    # )
    
    print("\nRaw Results:", all_results)

    # # Convert numpy/list tours to plain lists for JSON
    # def make_json_serializable(obj):
    #     if isinstance(obj, np.ndarray): return obj.tolist()
    #     if isinstance(obj, list): return [make_json_serializable(item) for item in obj]
    #     if isinstance(obj, dict): return {k: make_json_serializable(v) for k, v in obj.items()}
    #     if isinstance(obj, tuple): return tuple(make_json_serializable(item) for item in obj)
    #     # Add other types as needed (Path, etc.)
    #     if isinstance(obj, Path): return str(obj)
    #     if isinstance(obj, (int, float, str, bool)) or obj is None: return obj
    #     return repr(obj) # Fallback

    # # Convert tuple keys to strings
    # string_key_results = {str(k): make_json_serializable(v) for k, v in all_results.items()}

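    BASE_PATH = './benchmark_data/'
    results_path = f"{BASE_PATH}concorde_tsp_outputs/benchmark_results_nodes_{NUM_NODES}_instances_{NUM_TEST_INSTANCES}_run_{run_num}.json"

    try:
        import json  # Imported here so this block does not depend on another cell
        os.makedirs(os.path.dirname(results_path), exist_ok=True)  # Create the output directory if missing
        with open(results_path, "w") as f:
            json.dump(all_results, f, indent=4)
        print(f"\nBenchmark results saved to {results_path}")
    except Exception as e:
        print(f"\nError saving results to JSON: {e}")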

# 8. Analyze GNN predictions
# print("\nAnalyzing GNN edge predictions...")
# edge_probs = gnn_solver_v1_bc.get_edge_probabilities(test_instances[instance_idx])

# # Calculate prediction accuracy
# # Create ground truth edge mask from optimal tour
# true_edges = np.zeros((config['num_nodes'], config['num_nodes']))
# for i in range(len(tour_pure)):
#     j = (i + 1) % len(tour_pure)
#     true_edges[tour_pure[i], tour_pure[j]] = 1
#     true_edges[tour_pure[j], tour_pure[i]] = 1

# # Convert edge probabilities to binary predictions using threshold
# pred_edges = (edge_probs > 0.5).astype(int)

# # Calculate accuracy metrics
# true_positives = np.sum((pred_edges == 1) & (true_edges == 1))
# false_positives = np.sum((pred_edges == 1) & (true_edges == 0))
# false_negatives = np.sum((pred_edges == 0) & (true_edges == 1))
# true_negatives = np.sum((pred_edges == 0) & (true_edges == 0))

# precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
# recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
# f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

# print(f"Edge prediction precision: {precision:.4f}")
# print(f"Edge prediction recall: {recall:.4f}")
# print(f"Edge prediction F1 score: {f1_score:.4f}")

# # 9. Analyze effect of different GNN parameters
# print("\nAnalyzing effect of different GNN threshold parameters...")

# thresholds = [0.3, 0.5, 0.7, 0.9]
# fixing_percentages = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# elimination_percentages = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# # Create test instance
# test_instance = test_instances[0]

# # Get baseline result
# _, baseline_length, baseline_time, _ = TSPComparison.pure_branch_cut_solve(test_instance)
# print(f"Baseline - Length: {baseline_length:.4f}, Time: {baseline_time:.4f}s")

# # Test different parameters
# for threshold in thresholds:
#     for fixing_percentage in fixing_percentages:
#       for elimination_percentage in elimination_percentages:
#         # Update solver parameters
#         gnn_solver_v1_bc.threshold = threshold
#         gnn_solver_v1_bc.fixing_percentage = fixing_percentage
#         # Set elimination_percentage separately (for example, fix it at 10%)
#         #gnn_solver.elimination_percentage = 0.1
#         gnn_solver_v1_bc.elimination_percentage = elimination_percentage # restored

#         # Solve
#         start_time = time.time()
#         _, tour_length, _ = gnn_solver_v1_bc.solve_tsp(test_instance)
#         solve_time = time.time() - start_time

#         # Print results
#         print(f"Threshold: {threshold}, Fixing: {fixing_percentage*100}%, Elimination: {elimination_percentage*100}% - Length: {tour_length:.4f}, Time: {solve_time:.4f}s")
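
The commented-out step 8 above compares thresholded GNN edge probabilities against the edges of a reference tour. A minimal self-contained sketch of that computation follows. The function names `tour_to_edge_mask` and `edge_prediction_metrics` are hypothetical and are not part of the notebook. The sketch assumes `edge_probs` is a symmetric N x N matrix and counts each undirected edge once.

```python
import numpy as np


def tour_to_edge_mask(tour, num_nodes):
    """Symmetric 0/1 matrix marking the edges used by a closed tour."""
    mask = np.zeros((num_nodes, num_nodes), dtype=int)
    for i, u in enumerate(tour):
        v = tour[(i + 1) % len(tour)]  # wrap around to close the tour
        mask[u, v] = 1
        mask[v, u] = 1
    return mask


def edge_prediction_metrics(edge_probs, tour, threshold=0.5):
    """Precision, recall and F1 of thresholded edge probabilities (hypothetical helper)."""
    edge_probs = np.asarray(edge_probs)
    n = edge_probs.shape[0]
    true_edges = tour_to_edge_mask(tour, n)
    pred_edges = (edge_probs > threshold).astype(int)

    # Assumption: the matrix is symmetric, so only the strict upper
    # triangle is inspected and each undirected edge is counted once.
    upper = np.triu_indices(n, k=1)
    t, p = true_edges[upper], pred_edges[upper]

    tp = int(np.sum((p == 1) & (t == 1)))
    fp = int(np.sum((p == 1) & (t == 0)))
    fn = int(np.sum((p == 0) & (t == 1)))

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}
```

Unlike the commented-out version, which sums over the full matrix, this variant counts every undirected edge once. Precision and recall come out the same for a symmetric matrix, and the raw counts are easier to read.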

In [None]:
# print("\nRaw Results:", all_results)

# import json
# # # Convert numpy/list tours to plain lists for JSON
# # def make_json_serializable(obj):
# #     if isinstance(obj, np.ndarray): return obj.tolist()
# #     if isinstance(obj, list): return [make_json_serializable(item) for item in obj]
# #     if isinstance(obj, dict): return {k: make_json_serializable(v) for k, v in obj.items()}
# #     if isinstance(obj, tuple): return tuple(make_json_serializable(item) for item in obj)
# #     # Add other types as needed (Path, etc.)
# #     if isinstance(obj, Path): return str(obj)
# #     if isinstance(obj, (int, float, str, bool)) or obj is None: return obj
# #     return repr(obj) # Fallback

# # # Convert tuple keys to strings
# # string_key_results = {str(k): make_json_serializable(v) for k, v in all_results.items()}

# BASE_PATH = './benchmark_data/'

# try:
#     with open(f"{BASE_PATH}concorde_tsp_outputs/benchmark_results_{NUM_NODES}_{NUM_TEST_INSTANCES}.json", "w") as f:
#         json.dump(string_key_results, f, indent=4)
#     print("\nBenchmark results saved to benchmark_results.json")
# except Exception as e:
#     print(f"\nError saving results to JSON: {e}")
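
The commented-out helper above converts the results dictionary into JSON-friendly types before saving. A runnable sketch of the same idea is shown below. It additionally handles NumPy scalars and converts tuples to lists, both of which are assumptions beyond the original helper. `example_results` is made-up data used only for illustration.

```python
import json
from pathlib import Path

import numpy as np


def make_json_serializable(obj):
    """Recursively convert NumPy and pathlib objects into JSON-friendly types."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):  # NumPy scalars such as np.float64 or np.int64
        return obj.item()
    if isinstance(obj, dict):
        # JSON object keys must be strings, so tuple keys are stringified
        return {str(k): make_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [make_json_serializable(item) for item in obj]
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, (int, float, str, bool)) or obj is None:
        return obj
    return repr(obj)  # Fallback for unknown types


# Made-up example data, not taken from an actual benchmark run
example_results = {
    "pure_bc": {"tour": [np.array([0, 2, 1])], "avg_tour_lengths": np.float64(3.5)}
}
encoded = json.dumps(make_json_serializable(example_results), indent=4)
```

Note that Python's `json` module writes `float('inf')` as `Infinity` by default, which is not valid strict JSON. If the files are read by other tools, replacing such values with `None` before saving may be preferable.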

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

RESULTS_FILENAME = "hyperparameter_tuning_results_all.csv"
print(f"\n--- Analyzing Results from {RESULTS_FILENAME} ---")

# --- Load the data from the CSV ---
try:
    df_results = pd.read_csv(RESULTS_FILENAME)
    # Replace inf with NaN for convenience
    df_results.replace([np.inf, -np.inf], np.nan, inplace=True)
    # Coerce the column types where needed
    df_results['Beam Size'] = pd.to_numeric(df_results['Beam Size'], errors='coerce').fillna(0).astype(int)
    df_results['Elimination %'] = pd.to_numeric(df_results['Elimination %'], errors='coerce')
    df_results['Threshold'] = pd.to_numeric(df_results['Threshold'], errors='coerce')
    df_results['Avg Length'] = pd.to_numeric(df_results['Avg Length'], errors='coerce')
    df_results['Avg Time (s)'] = pd.to_numeric(df_results['Avg Time (s)'], errors='coerce')
    df_results['Length Diff'] = pd.to_numeric(df_results['Length Diff'], errors='coerce')
    df_results['Speedup vs Pure'] = pd.to_numeric(df_results['Speedup vs Pure'], errors='coerce')

    print("Loaded DataFrame:")
    print(df_results.head())
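except FileNotFoundError:
    print(f"Error: Results file '{RESULTS_FILENAME}' not found.")
    df_results = pd.DataFrame()  # Fall through to the "No data to analyze." branch below
except Exception as e:
    print(f"Error loading or processing results file: {e}")
    df_results = pd.DataFrame()  # Same: avoid exit(), which would stop the notebook kernel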

if df_results.empty:
    print("No data to analyze.")
else:
    # --- AVERAGE Pure B&C time across all records ---
    # pure_bc_times = df_results[df_results['Method'] == 'Pure B&C']['Avg Time (s)'].dropna()
    # avg_pure_bc_time = pure_bc_times.mean() if not pure_bc_times.empty else np.nan
    avg_pure_bc_time = 6.9507  # Hard-coded; the CSV-based computation above is commented out
    print(f"\nAverage Pure B&C time across all runs: {avg_pure_bc_time:.4f}s")

    # --- Recompute Speedup from the AVERAGE Pure B&C time ---
    if not np.isnan(avg_pure_bc_time) and avg_pure_bc_time > 0:
        # Overwrite the column; rows with a missing or non-positive time become NaN
        valid_times = df_results['Avg Time (s)'].where(df_results['Avg Time (s)'] > 0)
        df_results['Speedup vs Pure'] = avg_pure_bc_time / valid_times
        print("Recalculated 'Speedup vs Pure' using average Pure B&C time.")
    else:
        print("Could not calculate average Pure B&C time, Speedup calculation skipped.")
    # --- Dependency plots ---
    sns.set_theme(style="whitegrid")

    # Select the methods whose parameter dependence is analyzed,
    # here GNN+BS (v2) and GNN+B&C (v2)
    methods_to_analyze = ["GNN+BS (v2)", "GNN+B&C (v2)"]
    df_filtered = df_results[df_results['Method'].isin(methods_to_analyze)]

    # --- Plot 1: tour length vs. Beam Size (GNN+BS) ---
    df_bs = df_filtered[df_filtered['Method'].str.contains(r"GNN\+BS")]  # Raw string avoids an invalid-escape warning
    if not df_bs.empty:
         plt.figure(figsize=(10, 6))
         sns.lineplot(data=df_bs, x='Beam Size', y='Avg Length', hue='Method', marker='o')
         plt.title('Влияние Beam Size на Качество GNN+BS')
         plt.xlabel('Beam Size (K)')
         plt.ylabel('Средняя Длина Тура')
         plt.grid(True, linestyle='--', alpha=0.7)
         plt.legend(title='Метод')
         plt.show()

    # --- Plot 2: solving time vs. Beam Size (GNN+BS) ---
    if not df_bs.empty:
         plt.figure(figsize=(10, 6))
         sns.lineplot(data=df_bs, x='Beam Size', y='Avg Time (s)', hue='Method', marker='o')
         plt.title('Влияние Beam Size на Время GNN+BS')
         plt.xlabel('Beam Size (K)')
         plt.ylabel('Среднее Время (с)')
         plt.yscale('log') # Times are easier to compare on a log scale
         plt.grid(True, linestyle='--', alpha=0.7)
         plt.legend(title='Метод')
         plt.show()

    # --- Plot 3: GNN+B&C quality vs. Elimination % (fixed Threshold) ---
    df_bc = df_filtered[df_filtered['Method'].str.contains(r"GNN\+B&C")]  # Raw string avoids an invalid-escape warning
    fixed_threshold = 0.5 # Threshold value to analyze
    df_bc_thresh = df_bc[np.isclose(df_bc['Threshold'], fixed_threshold)]
    if not df_bc_thresh.empty:
         plt.figure(figsize=(10, 6))
         sns.lineplot(data=df_bc_thresh, x='Elimination %', y='Length Diff', hue='Method', marker='o')
         plt.title(f'Влияние Elimination % на Отклонение GNN+B&C (Threshold={fixed_threshold})')
         plt.xlabel('Elimination Percentage')
         plt.ylabel('Среднее Отклонение от Оптимума (Length Diff)')
         plt.xticks(np.arange(0, 1, 0.1)) # Ticks from 0 to 0.9
         plt.grid(True, linestyle='--', alpha=0.7)
         plt.legend(title='Метод')
         plt.axhline(0, color='grey', linestyle=':', linewidth=1)
         plt.show()

    # --- Plot 4: GNN+B&C solving time vs. Elimination % (fixed Threshold) ---
    if not df_bc_thresh.empty:
         plt.figure(figsize=(10, 6))
         sns.lineplot(data=df_bc_thresh, x='Elimination %', y='Avg Time (s)', hue='Method', marker='o')
         plt.title(f'Влияние Elimination % на Время GNN+B&C (Threshold={fixed_threshold})')
         plt.xlabel('Elimination Percentage')
         plt.ylabel('Среднее Время (с)')
         plt.xticks(np.arange(0, 1, 0.1))
         plt.yscale('log')
         plt.grid(True, linestyle='--', alpha=0.7)
         plt.legend(title='Метод')
         plt.show()

    # --- Best-parameter table (example) ---
    print("\n--- Лучшие результаты для каждого метода ---")
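    # Row with the minimum average length for each method (rows without a length are skipped)
    df_results_valid_len = df_results.dropna(subset=['Avg Length'])
    best_results = df_results_valid_len.loc[df_results_valid_len.groupby('Method')['Avg Length'].idxmin()]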
    print(best_results[['Method', 'Beam Size', 'Elimination %', 'Threshold', 'Avg Length', 'Length Diff', 'Avg Time (s)', 'Speedup vs Pure']].round(4).to_string(index=False))

    print("\n--- Самые быстрые результаты (с приемлемым качеством, например, Diff < 0.1) ---")
    fast_good_results = df_results[(df_results['Length Diff'].fillna(float('inf')) < 0.1)].sort_values('Avg Time (s)')
    print(fast_good_results[['Method', 'Beam Size', 'Elimination %', 'Threshold', 'Avg Length', 'Length Diff', 'Avg Time (s)', 'Speedup vs Pure']].round(4).to_string(index=False))

    # --- Fastest configuration for each method ---
    print("\n--- Лучшие результаты по ВРЕМЕНИ для каждого метода ---")
    # Row with the highest speedup (lowest average time) for each method
    # Drop NaN values before searching for the maximum
    df_results_valid_time = df_results.dropna(subset=['Speedup vs Pure'])
    if not df_results_valid_time.empty:
        best_time_results = df_results_valid_time.loc[df_results_valid_time.groupby('Method')['Speedup vs Pure'].idxmax()]
        print(best_time_results[['Method', 'Beam Size', 'Elimination %', 'Threshold', 'Avg Length', 'Length Diff', 'Avg Time (s)', 'Speedup vs Pure']].round(4).to_string(index=False))
    else:
        print("Нет валидных данных по времени для поиска лучших.")


    # --- Fastest results with acceptable quality ---
    print("\n--- Самые БЫСТРЫЕ результаты с приемлемым качеством (например, Diff < 0.1) ---")
    quality_threshold = 0.1 # Threshold for Length Diff
    # Filter by quality, then sort by speedup (fastest first)
    fast_good_results_by_time = df_results[(df_results['Length Diff'].fillna(float('inf')) < quality_threshold)].sort_values('Speedup vs Pure', ascending=False)
    if not fast_good_results_by_time.empty:
        print(fast_good_results_by_time[['Method', 'Beam Size', 'Elimination %', 'Threshold', 'Avg Length', 'Length Diff', 'Avg Time (s)', 'Speedup vs Pure']].round(4).to_string(index=False))
    else:
        print(f"Нет результатов с Length Diff < {quality_threshold}.")

# ==============================================================================
# 10. TOUR LENGTH DISTRIBUTION (VIOLIN PLOT)
# ==============================================================================
import pandas as pd
import seaborn as sns
import textwrap # Already imported above; repeated for clarity
import numpy as np
import matplotlib.pyplot as plt

print("\nGenerating Violin Plot for Tour Length Distribution...")

# --- Prepare the data for the plot ---
plot_data_list_violin = []
# Method names as stored in the keys of all_results
method_keys_for_plot = ['pure_bc', 'gnn_bs_v1', 'gnn_bs_v2', 'gnn_bc_v1', 'gnn_bc_v2']
# Display names used on the plot
method_display_names = {
    'pure_bc': 'Pure B&C',
    'gnn_bs_v1': 'GNN+BS (v1)',
    'gnn_bs_v2': 'GNN+BS (v2)',
    'gnn_bc_v1': 'GNN+B&C (v1)',
    'gnn_bc_v2': 'GNN+B&C (v2)'
}

# --- Instance dimensions ---
# Each length in a results list has to be matched with the dimension of its instance.
# This assumes the order of test_instances matches the order of the
# tour_lengths/solving_times/tour lists inside all_results.
instance_dimensions = [inst.shape[0] for inst in test_instances] # List of dimensions [N1, N2, ...]

if not all_results:
     print("Error: 'all_results' dictionary is empty.")
else:
    # Iterate over the methods
    for method_key in method_keys_for_plot:
        if method_key in all_results and all_results[method_key]:
            method_data = all_results[method_key]
            display_name = method_display_names.get(method_key, method_key)

            # List of tour lengths for this method
            lengths_list = method_data.get('tour_lengths', [])

            # Check that the list lengths match
            if len(lengths_list) != len(instance_dimensions):
                print(f"Warning: Mismatch between number of lengths ({len(lengths_list)}) and instances ({len(instance_dimensions)}) for method '{display_name}'. Skipping this method for violin plot.")
                continue

            # Build the rows for the DataFrame
            for i, length_val in enumerate(lengths_list):
                # Skip invalid lengths (None or inf)
                if length_val is None or length_val == float('inf'):
                    continue

                dimension_val = instance_dimensions[i] # Dimension of the matching instance
                plot_data_list_violin.append({
                    'solver_name': display_name,
                    'dimension': dimension_val,
                    'length': length_val
                })
        else:
             print(f"Warning: No data found for method '{method_key}' in all_results.")


    # --- Build the DataFrame ---
    plot_df_violin = pd.DataFrame(plot_data_list_violin)

    if plot_df_violin.empty:
        print("Error: No valid data available to generate violin plot after processing.")
    else:
        print("\nDataFrame for violin plot prepared:")
        print(plot_df_violin.head())
        print(f"\nDimensions in violin plot data: {sorted(plot_df_violin['dimension'].unique())}")
        print(f"Solvers in violin plot data: {plot_df_violin['solver_name'].unique()}")

        # --- Violin plot ---
        print("\nGenerating violin plot...")

        # Wrap long method names
        wrap_width = 18
        plot_df_violin['solver_wrapped'] = plot_df_violin['solver_name'].apply(
            lambda x: textwrap.fill(str(x).replace("_", " "), wrap_width)
        )

        # Unique dimensions, sorted
        dimensions_found = sorted(plot_df_violin['dimension'].unique())
        dimensions_found = [d for d in dimensions_found if not np.isnan(d)]

        if not dimensions_found:
             print("Error: No valid dimensions found in the data.")
        else:
            solver_order = plot_df_violin.groupby('solver_wrapped')['length'].median().sort_values().index

            # --- CHANGE: smaller height and aspect ---
            num_dims = len(dimensions_found)
            col_wrap_val = min(num_dims, 3)
            # Fixed height, independent of the number of solvers
            plot_height = 6
            # Aspect ratio slightly wider than square
            plot_aspect = 1.2

            g = sns.FacetGrid(plot_df_violin, col="dimension", col_order=dimensions_found,
                              height=plot_height,
                              aspect=plot_aspect,
                              sharex=False, col_wrap=col_wrap_val)
            # --- END OF CHANGE ---

            # Draw the violin plots
            g.map_dataframe(sns.violinplot, x="length", y="solver_wrapped",
                            palette="viridis", orient='h', cut=0, inner=None, linewidth=1.0,
                            order=solver_order)

            # Overlay the individual points
            g.map_dataframe(sns.stripplot, x="length", y="solver_wrapped",
                            color="black", alpha=0.4, size=3, orient='h', jitter=0.15,
                            order=solver_order)

            # Plot styling (unchanged)
            g.fig.suptitle('Длины решений', y=1.03, fontsize=16)
            g.set_titles("N = {col_name:.0f}")
            g.set_axis_labels("Длина Найденного Тура", "Алгоритм")
            for ax in g.axes.flat:
                ax.tick_params(axis='y', labelsize=9)

            plt.tight_layout(rect=[0, 0, 1, 0.97])
            plt.show()

            # Build the FacetGrid (adaptive height and aspect)
            num_dims = len(dimensions_found)
            col_wrap_val = min(num_dims, 3)
            g = sns.FacetGrid(plot_df_violin, col="dimension", col_order=dimensions_found,
                              height=max(5, len(solver_order) * 0.8), # Height adapts to the number of solvers
                              aspect= max(0.5, 6 / col_wrap_val / max(4, len(solver_order)*0.8) * 5), # Adaptive aspect ratio
                              sharex=False, col_wrap=col_wrap_val)

            # Draw the violin plots
            g.map_dataframe(sns.violinplot, x="length", y="solver_wrapped",
                            palette="viridis", orient='h', cut=0, inner=None, linewidth=1.0,
                            order=solver_order) # Use the sorted order

            # Overlay the individual points (stripplot)
            g.map_dataframe(sns.stripplot, x="length", y="solver_wrapped",
                            color="black", alpha=0.4, size=3, orient='h', jitter=0.15,
                            order=solver_order) # Same order

            # Plot styling
            g.fig.suptitle('Длины решений', y=1.03, fontsize=16)
            g.set_titles("N = {col_name:.0f}")
            g.set_axis_labels("Длина Найденного Тура", "Алгоритм")

            # Improve the readability of the Y axis
            for ax in g.axes.flat:
                ax.tick_params(axis='y', labelsize=9)
                # Optional: add mean/median lines if needed

            plt.tight_layout(rect=[0, 0, 1, 0.97])
            plt.show()
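
The data preparation for the violin plot (aligning per-instance tour lengths with instance sizes and dropping invalid values) can be condensed into one helper. The sketch below is an illustration only. `results_to_long_frame` is a hypothetical name, and the input is assumed to have the same shape as `all_results`, with a `tour_lengths` list per method.

```python
import math

import pandas as pd


def results_to_long_frame(results, instance_sizes, display_names=None):
    """Build a long-format DataFrame with one row per (method, instance)."""
    display_names = display_names or {}
    rows = []
    for key, data in results.items():
        lengths = (data or {}).get("tour_lengths", [])
        if len(lengths) != len(instance_sizes):
            continue  # Lengths cannot be aligned with the instances; skip the method
        for size, length in zip(instance_sizes, lengths):
            if length is None or not math.isfinite(length):
                continue  # Skip failed runs (None, inf, NaN)
            rows.append({
                "solver_name": display_names.get(key, key),
                "dimension": size,
                "length": float(length),
            })
    return pd.DataFrame(rows, columns=["solver_name", "dimension", "length"])
```

The resulting frame has the same columns as `plot_df_violin` before the `solver_wrapped` column is added, so it could be passed to `sns.FacetGrid` in the same way.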


In [None]:
# ==============================================================================
# 11. TIME VS. QUALITY (SCATTER PLOT)
# ==============================================================================
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from adjustText import adjust_text

print("\nGenerating Time vs Quality Scatter Plot...")

# --- Convert aggregated_results to a DataFrame ---
scatter_data_list = []
# Method keys and display names (same as before)
method_keys_for_plot = ['pure_bc', 'gnn_bs_v1', 'gnn_bs_v2', 'gnn_bc_v1', 'gnn_bc_v2']
method_display_names = {
    'pure_bc': 'Pure B&C', 'gnn_bs_v1': 'GNN+BS (v1)', 'gnn_bs_v2': 'GNN+BS (v2)',
    'gnn_bc_v1': 'GNN+B&C (v1)', 'gnn_bc_v2': 'GNN+B&C (v2)'
}

# Reference (optimal) length, if available and consistent
optimal_len = aggregated_results.get('pure_bc', {}).get('avg_tour_lengths', np.nan)

for method_key in method_keys_for_plot:
    if method_key in aggregated_results:
        res = aggregated_results[method_key]
        avg_len = res.get('avg_tour_lengths', np.nan)
        avg_time = res.get('avg_solving_times', np.nan)
        display_name = method_display_names.get(method_key, method_key)

        # Compute the gap (%); inf placeholders for untested methods yield NaN
        gap = np.nan
        if np.isfinite(optimal_len) and optimal_len > 0 and np.isfinite(avg_len):
            gap = (avg_len - optimal_len) / optimal_len * 100

        scatter_data_list.append({
            'Method': display_name,
            'Average Time (s)': avg_time,
            'Average Length': avg_len,
            'Gap (%)': gap
        })

scatter_df = pd.DataFrame(scatter_data_list)

if scatter_df.empty or scatter_df[['Average Time (s)', 'Gap (%)']].dropna().empty:
    print("Error: Not enough valid data to generate scatter plot.")
else:
    print("\nDataFrame for Scatter Plot:")
    print(scatter_df)

    # --- Build the plot ---
    plt.figure(figsize=(12, 8))
    ax = plt.gca() # Axes handle, passed to adjust_text below

    # seaborn scatterplot
    scatter_plot = sns.scatterplot(
        data=scatter_df,
        x='Average Time (s)',
        y='Gap (%)',
        hue='Method',
        style='Method',
        s=220,
        palette='viridis',
        legend='full', # Keep the legend in case adjust_text fails
        ax=ax # Draw on these axes
    )

    texts = []
    for i in range(scatter_df.shape[0]):
         row = scatter_df.iloc[i]
         # Only label points whose coordinates are valid
         if not pd.isna(row['Average Time (s)']) and not pd.isna(row['Gap (%)']):
             texts.append(ax.text(row['Average Time (s)'], row['Gap (%)'], row['Method'], fontsize=11))

    # Let adjust_text place the labels automatically
    if texts:
        adjust_text(texts, ax=ax,
                    # Repulsion settings:
                    expand_points=(1.2, 1.2), # Extra distance from the points
                    force_points=(0.2, 0.2), # Repulsion from the points
                    force_text=(0.3, 0.5),    # Repulsion between the labels
                    arrowprops=dict(arrowstyle="-", color='gray', lw=0.5, alpha=0.7) # Lines from labels to points
                   )
    else:
        print("No valid points found to label.")

    # Axes and title
    plt.title('Time-Quality Trade-off for Different TSP Methods', fontsize=16)
    plt.xlabel('Average Solving Time (s)')
    plt.ylabel('Average Gap to Optimum (%)')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.axhline(0, color='grey', linestyle=':', linewidth=1)

    # Optional: logarithmic time axis
    # plt.xscale('log')
    # plt.xlabel('Average Solving Time (s, log scale)')

    # Move the seaborn legend outside the axes so it does not cover the points
    handles, labels = ax.get_legend_handles_labels()
    if handles:  # Only build a legend when there is something to show
        # Drop the 'Method' title entry from the legend if present
        if labels[0].lower() == 'method':
            handles = handles[1:]
            labels = labels[1:]
        ax.legend(handles=handles, labels=labels, title="Methods", bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0.)

    plt.tight_layout(rect=[0, 0, 0.85, 1])  # Leave room on the right for the legend
    plt.show()
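
The `Gap (%)` column in the scatter plot is the relative excess of a method's average tour length over the optimal length. A minimal sketch of that calculation as a standalone helper follows. The function name `optimality_gap` and the sample lengths are hypothetical and serve only as an illustration; they are not measured results.

```python
import math


def optimality_gap(avg_len, optimal_len):
    """Return the gap in percent, or NaN when it cannot be computed."""
    # Mirror the guard used in the cell: both values must be valid numbers
    # and the optimal length must be positive.
    if math.isnan(avg_len) or math.isnan(optimal_len) or optimal_len <= 0:
        return math.nan
    return (avg_len - optimal_len) / optimal_len * 100


# Illustrative values only (hypothetical, not taken from any experiment).
gap_example = optimality_gap(4.2, 4.0)
gap_missing = optimality_gap(4.2, float("nan"))
print(f"Gap: {gap_example:.2f}%, gap without a known optimum: {gap_missing}")
```

Returning NaN rather than raising keeps methods without a known optimum in the DataFrame, so they can be filtered out with `dropna()` before plotting.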


In [None]:
# 3. Generate test instances
test_instances = generate_test_instances(num_instances=5, num_nodes=20)  # Adjust the counts as needed

# --- RUN TESTS WITH DIFFERENT GNN STRATEGIES ---

# --- Test 1: GNN with fixing disabled ---
print("\n===== TEST 1: GNN-guided B&C (Fixing OFF, Elim ON) =====")
gnn_solver_elim_only = GNNBranchCutSolver(
    gnn_model=net,
    threshold=0.6,              # Fixing threshold; unused here because fixing is off
    fixing_percentage=0.0,      # <--- FIXING DISABLED
    elimination_percentage=0.2  # Keep elimination at 20%
)
print("Comparing Pure B&C with GNN-guided (Elimination Only)...")
results_elim_only = TSPComparison.compare_methods(gnn_solver_elim_only, test_instances)

print("\nResults Summary (Elimination Only):")
print(f"Pure Branch & Cut - Avg Tour Length: {results_elim_only['pure_bc']['avg_tour_lengths']:.4f}, Avg Time: {results_elim_only['pure_bc']['avg_solving_times']:.4f}s")
print(f"GNN (Elim Only) - Avg Tour Length: {results_elim_only['gnn_bc']['avg_tour_lengths']:.4f}, Avg Time: {results_elim_only['gnn_bc']['avg_solving_times']:.4f}s")
print("---------------------------------------------------\n")


# --- Test 2: GNN with very conservative fixing ---
print("\n===== TEST 2: GNN-guided B&C (Conservative Fixing, Elim ON) =====")
gnn_solver_conservative_fix = GNNBranchCutSolver(
    gnn_model=net,
    threshold=0.95,             # <--- VERY HIGH THRESHOLD
    fixing_percentage=0.1,      # <--- FIX A SMALLER SHARE (only the most confident edges)
    elimination_percentage=0.2  # Keep elimination at 20%
)
print("Comparing Pure B&C with GNN-guided (Conservative Fixing)...")
# Optionally pass debug_level=1 to this call to see how many edges actually get fixed
results_conservative_fix = TSPComparison.compare_methods(gnn_solver_conservative_fix, test_instances)

print("\nResults Summary (Conservative Fixing):")
print(f"Pure Branch & Cut - Avg Tour Length: {results_conservative_fix['pure_bc']['avg_tour_lengths']:.4f}, Avg Time: {results_conservative_fix['pure_bc']['avg_solving_times']:.4f}s")
print(f"GNN (Cons. Fix) - Avg Tour Length: {results_conservative_fix['gnn_bc']['avg_tour_lengths']:.4f}, Avg Time: {results_conservative_fix['gnn_bc']['avg_solving_times']:.4f}s")
print("---------------------------------------------------\n")


# --- Test 3: GNN with elimination disabled ---
print("\n===== TEST 3: GNN-guided B&C (Fixing ON, Elim OFF) =====")
gnn_solver_fix_only = GNNBranchCutSolver(
    gnn_model=net,
    threshold=0.6,              # Back to the original threshold
    fixing_percentage=0.2,      # Back to the original percentage
    elimination_percentage=0.0  # <--- ELIMINATION DISABLED
)
print("Comparing Pure B&C with GNN-guided (Fixing Only)...")
# Pass debug_level=1 to analyze fixing in detail (not enabled in this call)
results_fix_only = TSPComparison.compare_methods(gnn_solver_fix_only, test_instances)

print("\nResults Summary (Fixing Only):")
print(f"Pure Branch & Cut - Avg Tour Length: {results_fix_only['pure_bc']['avg_tour_lengths']:.4f}, Avg Time: {results_fix_only['pure_bc']['avg_solving_times']:.4f}s")
print(f"GNN (Fix Only)  - Avg Tour Length: {results_fix_only['gnn_bc']['avg_tour_lengths']:.4f}, Avg Time: {results_fix_only['gnn_bc']['avg_solving_times']:.4f}s")
print("---------------------------------------------------\n")


# --- (Optional) Original configuration for comparison ---
# print("\n===== ORIGINAL GNN-guided B&C =====")
# gnn_solver_original = GNNBranchCutSolver(
#     gnn_model=net,
#     threshold=0.6,
#     fixing_percentage=0.2,
#     elimination_percentage=0.2
# )
# print("Comparing Pure B&C with Original GNN-guided...")
# results_original = TSPComparison.compare_methods(gnn_solver_original, test_instances, debug_level=0)
#
# print("\nResults Summary (Original):")
# print(f"Pure Branch & Cut - Avg Tour Length: {results_original['pure_bc']['avg_tour_lengths']:.4f}, Avg Time: {results_original['pure_bc']['avg_solving_times']:.4f}s")
# print(f"GNN (Original)  - Avg Tour Length: {results_original['gnn_bc']['avg_tour_lengths']:.4f}, Avg Time: {results_original['gnn_bc']['avg_solving_times']:.4f}s")
# print("---------------------------------------------------\n")


# --- Analysis and conclusions from the tests ---
print("\n===== OVERALL TEST CONCLUSIONS =====")
# Compare results_elim_only, results_conservative_fix and results_fix_only
# (plus results_original, if it was run) against the Pure B&C baseline taken from
# any one of them: the tour lengths should match, though timings may vary between runs.

pure_avg_len = results_elim_only['pure_bc']['avg_tour_lengths']
pure_avg_time = results_elim_only['pure_bc']['avg_solving_times']

print(f"Pure B&C:           Avg Len={pure_avg_len:.4f}, Avg Time={pure_avg_time:.4f}s")

elim_only_avg_len = results_elim_only['gnn_bc']['avg_tour_lengths']
elim_only_avg_time = results_elim_only['gnn_bc']['avg_solving_times']
print(f"GNN (Elim Only):    Avg Len={elim_only_avg_len:.4f} (Diff: {elim_only_avg_len-pure_avg_len:+.4f}), Avg Time={elim_only_avg_time:.4f}s (Speedup: {pure_avg_time/elim_only_avg_time:.2f}x)")

cons_fix_avg_len = results_conservative_fix['gnn_bc']['avg_tour_lengths']
cons_fix_avg_time = results_conservative_fix['gnn_bc']['avg_solving_times']
print(f"GNN (Conserv. Fix): Avg Len={cons_fix_avg_len:.4f} (Diff: {cons_fix_avg_len-pure_avg_len:+.4f}), Avg Time={cons_fix_avg_time:.4f}s (Speedup: {pure_avg_time/cons_fix_avg_time:.2f}x)")

fix_only_avg_len = results_fix_only['gnn_bc']['avg_tour_lengths']
fix_only_avg_time = results_fix_only['gnn_bc']['avg_solving_times']
print(f"GNN (Fix Only):     Avg Len={fix_only_avg_len:.4f} (Diff: {fix_only_avg_len-pure_avg_len:+.4f}), Avg Time={fix_only_avg_time:.4f}s (Speedup: {pure_avg_time/fix_only_avg_time:.2f}x)")

# Conclusions drawn from the comparison. Each check is independent, so every
# applicable finding is reported rather than only the first match.
len_tol = 1e-6  # Tolerance for floating-point noise when comparing tour lengths
if elim_only_avg_len <= pure_avg_len + len_tol and elim_only_avg_time < pure_avg_time:
    print("\n-> The 'Elimination Only' strategy looks promising: it preserves quality and runs faster.")
if cons_fix_avg_len <= pure_avg_len + len_tol and cons_fix_avg_time < pure_avg_time:
    print("\n-> The 'Conservative Fixing' strategy looks promising: it preserves quality and runs faster.")
if np.isfinite(fix_only_avg_len) and fix_only_avg_len > pure_avg_len + len_tol:
    print("\n-> The 'Fixing Only' strategy (even without elimination) yields suboptimal solutions. Fixing is the main problem.")
if elim_only_avg_len > pure_avg_len + len_tol:
    print("\n-> The 'Elimination Only' strategy degrades quality. Needed edges may be getting removed, or the GNN may be ranking them poorly.")

print("\nRecommendation: review the detailed test results. If 'Elimination Only' or 'Conservative Fixing' performs well, use it. Otherwise, focus on improving the GNN model.")

# 8. Visualize comparison (pick the best results to visualize)
# For example, compare Pure B&C with GNN (Elim Only)
# TSPComparison.plot_comparison_results(results_elim_only)

# 9. Detailed visualization of one instance (compare Pure with the best GNN variant)
# instance_idx = 0
# tour_pure, tour_length_pure, _ , _= TSPComparison.pure_branch_cut_solve(test_instances[instance_idx])
# tour_gnn_best, tour_length_gnn_best, _ = gnn_solver_elim_only.solve_tsp(test_instances[instance_idx]) # Example
# TSPComparison.visualize_tours(...)
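
The per-strategy summary lines in this cell repeat the same length-difference and speedup arithmetic three times. One possible way to factor it out is sketched below, assuming the result dictionaries keep the `pure_bc` / `gnn_bc` layout with `avg_tour_lengths` and `avg_solving_times` keys used in this cell. The helper name `summarize_strategy` and the numbers in `demo_results` are hypothetical placeholders, not measured results.

```python
import math


def summarize_strategy(results, method_key="gnn_bc", baseline_key="pure_bc"):
    """Return (length difference, speedup) of a method against the baseline."""
    base = results[baseline_key]
    method = results[method_key]
    len_diff = method["avg_tour_lengths"] - base["avg_tour_lengths"]
    method_time = method["avg_solving_times"]
    # Guard against a zero or NaN time instead of raising ZeroDivisionError.
    if method_time and method_time > 0:
        speedup = base["avg_solving_times"] / method_time
    else:
        speedup = math.nan
    return len_diff, speedup


# Hypothetical placeholder numbers, only to show the call shape.
demo_results = {
    "pure_bc": {"avg_tour_lengths": 4.0, "avg_solving_times": 2.0},
    "gnn_bc": {"avg_tour_lengths": 4.0, "avg_solving_times": 0.5},
}
diff, speedup = summarize_strategy(demo_results)
print(f"Diff: {diff:+.4f}, Speedup: {speedup:.2f}x")
```

With such a helper, each strategy's summary line could be produced by a single call on `results_elim_only`, `results_conservative_fix`, or `results_fix_only`.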