# SET UP

In [None]:
!pip uninstall numpy -y

In [None]:
!pip install numpy==1.26.4

In [None]:
import numpy
print(numpy.__version__)

In [None]:
import torch

def check_cuda_streams():
    """Print the current CUDA stream, or a notice when CUDA is unavailable."""
    if torch.cuda.is_available():
        stream = torch.cuda.current_stream()
        print(f"Current CUDA stream: {stream}")
        return
    print("CUDA is not available.")

check_cuda_streams()

In [None]:
!nvidia-smi

In [None]:
import psutil

def check_cpu_info():
    """Print core counts, the current CPU frequency and the CPU utilisation.

    The utilisation is sampled over a one second interval, so the call
    blocks for about one second.
    """
    # Physical cores vs. logical cores (the latter include hyperthreads).
    physical_cores = psutil.cpu_count(logical=False)
    logical_cores = psutil.cpu_count(logical=True)

    print(f"Number of Physical Cores: {physical_cores}")
    print(f"Number of Logical Cores (Threads): {logical_cores}")

    # psutil.cpu_freq() returns None when the frequency cannot be determined
    # (this can happen on virtual machines / containers), so guard the
    # attribute access instead of crashing with AttributeError.
    cpu_freq = psutil.cpu_freq()
    if cpu_freq is None:
        print("CPU Frequency: unavailable")
    else:
        print(f"CPU Frequency: {cpu_freq.current:.2f} MHz")

    # Utilisation measured over a 1 s window.
    cpu_usage = psutil.cpu_percent(interval=1)
    print(f"CPU Utilization: {cpu_usage:.2f}%")

check_cpu_info()

In [None]:
!pip install -q condacolab
import condacolab
condacolab.install()

In [None]:
!conda update -n base -c defaults conda
!conda install -n base -c conda-forge mamba

In [None]:
!which conda

In [None]:
!conda install -n base -c conda-forge ecole

In [None]:
!conda install -n base -c conda-forge scip=8.0

In [None]:
import numpy
print(numpy.__version__)

In [None]:
import ecole
print(dir(ecole))

In [None]:
!pip install torch torchvision torchaudio

In [None]:
!pip install torch-geometric

In [None]:
import torch
print(torch.__version__)         
print(torch.version.cuda)        
print(torch.backends.cudnn.version())
print(torch.cuda.is_available()) 

In [None]:
!pip list | grep torch

In [None]:
!pip install scipy

In [None]:
!pip install pyscipopt

# Drive Connection

In [None]:
from google.colab import drive
import time
import os

# Remount Google Drive from a clean state.
drive.flush_and_unmount()
drive.mount('/content/drive')

test_file = "/content/drive/MyDrive/drive_speed_test.bin"
file_size = 1024 * 1024 * 100  # 100 MB
bytes_per_mb = 1024 * 1024

# Build the payload before starting the clock so that generating 100 MB of
# random bytes is not counted as write time.
payload = os.urandom(file_size)

try:
    # Write Speed Test (perf_counter is monotonic, unlike time.time()).
    start_time = time.perf_counter()
    with open(test_file, "wb") as f:
        f.write(payload)
    write_time = time.perf_counter() - start_time
    write_speed = file_size / write_time / bytes_per_mb  # MB/s
    print(f"Write Speed: {write_speed:.2f} MB/s")

    # Read Speed Test
    # NOTE(review): the file was just written, so this read is presumably
    # served from a local cache rather than from Drive itself -- confirm
    # before relying on the number.
    start_time = time.perf_counter()
    with open(test_file, "rb") as f:
        f.read()
    read_time = time.perf_counter() - start_time
    read_speed = file_size / read_time / bytes_per_mb  # MB/s
    print(f"Read Speed: {read_speed:.2f} MB/s")
finally:
    # Clean up even when one of the steps above failed.
    if os.path.exists(test_file):
        os.remove(test_file)

In [None]:
os.chdir('/content/drive/My Drive/Thesis')

In [None]:
print(os.getcwd())

In [None]:
!ls

# Utilities

In [None]:

import gzip
import pickle
import datetime
import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric

def log(str, logfile=None):
    """Print a timestamped message and optionally append it to a file.

    Args:
        str: Message to log (the name shadows the builtin; it is kept so
            that existing keyword callers keep working).
        logfile: Path of a file to append the same line to, or ``None`` to
            only print.
    """
    stamped = f'[{datetime.datetime.now()}] {str}'
    print(stamped)
    if logfile is None:
        return
    with open(logfile, mode='a') as f:
        print(stamped, file=f)


def pad_tensor(input_, pad_sizes, pad_value=-1e8):
    """Split a tensor into consecutive chunks and stack them, right-padded.

    Args:
        input_: Tensor that is split along its first dimension.
        pad_sizes: Integer tensor with the length of every chunk.
        pad_value: Fill value used for the padding.

    Returns:
        A tensor with one row per chunk, each row padded on the right up to
        the largest chunk length.
    """
    widest = pad_sizes.max()
    chunks = input_.split(pad_sizes.cpu().numpy().tolist())
    padded_rows = []
    for chunk in chunks:
        missing = widest - chunk.size(0)
        padded_rows.append(F.pad(chunk, (0, missing), 'constant', pad_value))
    return torch.stack(padded_rows, dim=0)



class BipartiteNodeData(torch_geometric.data.Data):
    """One bipartite constraint/variable graph together with its branching data.

    Stored attributes: ``constraint_features``, ``edge_index``, ``edge_attr``,
    ``variable_features``, ``candidates``, ``nb_candidates``,
    ``candidate_choices`` and ``candidate_scores``.
    """

    def __init__(self, constraint_features, edge_indices, edge_features, variable_features,
                 candidates, nb_candidates, candidate_choice, candidate_scores):
        super().__init__()
        # Graph structure and node / edge features.
        self.constraint_features = constraint_features
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features
        # Branching candidates and their labels.
        self.candidates = candidates
        self.nb_candidates = nb_candidates
        self.candidate_choices = candidate_choice
        self.candidate_scores = candidate_scores

    def __inc__(self, key, value, store, *args, **kwargs):
        """Return the offset applied to ``key`` when graphs are concatenated.

        ``edge_index`` is shifted by (number of constraints, number of
        variables) row-wise and ``candidates`` by the number of variables;
        every other key falls back to the parent implementation.
        """
        if key == 'candidates':
            return self.variable_features.size(0)
        if key != 'edge_index':
            return super().__inc__(key, value, *args, **kwargs)
        n_constraints = self.constraint_features.size(0)
        n_variables = self.variable_features.size(0)
        return torch.tensor([[n_constraints], [n_variables]])


class GraphDataset(torch_geometric.data.Dataset):
    """Dataset that lazily loads gzip-pickled branching samples from disk."""

    def __init__(self, sample_files):
        super().__init__(root=None, transform=None, pre_transform=None)
        self.sample_files = sample_files

    def len(self):
        """Number of sample files."""
        return len(self.sample_files)

    def get(self, index):
        """Load sample ``index`` and convert it into a ``BipartiteNodeData``."""
        path = self.sample_files[index]
        with gzip.open(path, 'rb') as f:
            sample = pickle.load(f)

        observation, action, action_set, scores = sample['data']
        cons_feats, (edge_idx, edge_vals), var_feats = observation

        # Convert the numpy observation into torch tensors.
        constraint_features = torch.FloatTensor(cons_feats)
        edge_indices = torch.LongTensor(edge_idx.astype(np.int32))
        edge_features = torch.FloatTensor(np.expand_dims(edge_vals, axis=-1))
        variable_features = torch.FloatTensor(var_feats)

        candidates = torch.LongTensor(np.array(action_set, dtype=np.int32))
        # Position of the chosen action inside the candidate list.
        matches = torch.where(candidates == action)
        candidate_choice = matches[0][0]
        candidate_scores = torch.FloatTensor([scores[j] for j in candidates])

        graph = BipartiteNodeData(
            constraint_features, edge_indices, edge_features, variable_features,
            candidates, len(candidates), candidate_choice, candidate_scores,
        )
        graph.num_nodes = constraint_features.shape[0] + variable_features.shape[0]
        return graph



class Scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """``ReduceLROnPlateau`` variant with a simplified ``step``.

    The learning rate is reduced when the number of consecutive
    non-improving steps is exactly equal to ``patience``; the bad-epoch
    counter is only reset by an improving metric.
    """

    def __init__(self, optimizer, **kwargs):
        super().__init__(optimizer, **kwargs)

    def step(self, metrics):
        """Record one evaluation result and reduce the learning rate if due.

        Args:
            metrics: Monitored value; anything accepted by ``float()``.
        """
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        # Increment the epoch counter (the previous `=+1` assigned +1 on
        # every call instead of counting).
        self.last_epoch += 1

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs == self.patience:
            self._reduce_lr(self.last_epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

# Models

## Previous Models and Code References

**Original GCN (2019)**  
Based on the NeurIPS paper *Exact Combinatorial Optimization with Graph Convolutional Neural Networks*  
(Gasse, M., Chételat, D., Ferroni, N., Charlin, L., & Lodi, A., 2019).  
Code available at: [ds4dm/learn2branch](https://github.com/ds4dm/learn2branch)

**Attention GCN (Attention Mechanism)**  
*BiGNN: Bipartite Graph Neural Network with Attention Mechanism for Solving Multiple Traveling Salesman Problems in Urban Logistics*  
(Liang, H., Wang, S., & Li, H., n.d.).  
Code available at: [CO-RL/DeepMTSP](https://github.com/CO-RL/DeepMTSP)


## MAGCN

In [None]:
from torch_geometric.nn import MessagePassing
import torch_geometric

class PreNormException(Exception):
    """Raised to signal that a PreNormLayer has finished a pre-training update."""


class PreNormLayer(torch.nn.Module):
    """Fixed affine normalisation ``(input + shift) * scale``.

    ``shift`` and ``scale`` are buffers.  They start as zeros / ones and are
    replaced by statistics gathered during a pre-training phase:
    ``start_updates`` -> one or more ``forward``/``update_stats`` calls ->
    ``stop_updates``.

    Args:
        n_units: Size of the normalised (last) dimension.
        shift: Whether to keep a shift buffer.
        scale: Whether to keep a scale buffer.
        name: Unused by this class.
    """

    def __init__(self, n_units, shift=True, scale=True, name=None):
        super().__init__()
        assert shift or scale
        initial_shift = torch.zeros(n_units) if shift else None
        initial_scale = torch.ones(n_units) if scale else None
        self.register_buffer('shift', initial_shift)
        self.register_buffer('scale', initial_scale)
        self.n_units = n_units
        self.waiting_updates = False
        self.received_updates = False

    def forward(self, input_):
        """Normalise ``input_``; while pre-training, update stats and raise."""
        if self.waiting_updates:
            # Pre-training mode: consume the batch, then abort the forward
            # pass so that later layers do not see un-normalised data.
            self.update_stats(input_)
            self.received_updates = True
            raise PreNormException

        out = input_
        if self.shift is not None:
            out = out + self.shift
        if self.scale is not None:
            out = out * self.scale
        return out

    def start_updates(self):
        """Reset the running statistics and enter pre-training mode."""
        self.avg = self.var = self.m2 = self.count = 0
        self.waiting_updates = True
        self.received_updates = False

    def update_stats(self, input_):
        """
        Online mean and variance estimation. See: Chan et al. (1979) Updating
        Formulae and a Pairwise Algorithm for Computing Sample Variances.
        https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
        """
        assert self.n_units == 1 or input_.shape[-1] == self.n_units, f"Expected input dimension of size {self.n_units}, got {input_.shape[-1]}."

        flat = input_.reshape(-1, self.n_units)
        batch_avg = flat.mean(dim=0)
        batch_var = (flat - batch_avg).pow(2).mean(dim=0)
        batch_count = np.prod(flat.size())/self.n_units

        delta = batch_avg - self.avg

        # Merge the batch statistics into the running ones.
        self.m2 = self.var * self.count + batch_var * batch_count + delta ** 2 * self.count * batch_count / (
                self.count + batch_count)

        self.count += batch_count
        self.avg += delta * batch_count / self.count
        self.var = self.m2 / self.count if self.count > 0 else 1

    def stop_updates(self):
        """
        Ends pre-training for that layer, and fixes the layer's parameters.
        """
        assert self.count > 0
        if self.shift is not None:
            self.shift = -self.avg

        if self.scale is not None:
            # Units with (near) zero variance are left unscaled.
            self.var[self.var < 1e-8] = 1
            self.scale = 1 / torch.sqrt(self.var)

        del self.avg, self.var, self.m2, self.count
        self.waiting_updates = False
        self.trainable = False


##############################################################################
# Aggregation Functions
##############################################################################
def sum_aggregate(inputs, index, dim_size=None):
    """
    Sum aggregator using torch.index_add_.

    Args:
        inputs: [E, D] tensor (E edges, D embedding dim).
        index: [E] integer tensor giving the node each edge belongs to.
        dim_size: Number of output rows; defaults to ``index.max() + 1``
            (0 when ``index`` is empty).

    Returns:
        [dim_size, D] tensor with the per-node sums, using the dtype and
        device of ``inputs``.  Nodes that receive no edge are zero.
    """
    if dim_size is None:
        # index.max() raises on an empty tensor, so handle that case here.
        dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
    # new_zeros follows the dtype/device of `inputs`; index_add_ requires
    # the destination and the source to share a dtype.
    out = inputs.new_zeros((dim_size, inputs.size(1)))
    out.index_add_(0, index, inputs)
    return out


def mean_aggregate(inputs, index, dim_size=None):
    """
    Per-node mean of edge features: ``sum_aggregate`` divided by the number
    of edges that point at each node (from ``torch.bincount``).

    Nodes without any edge get a divisor of 1, so their mean is the
    (all-zero) sum.
    """
    n_groups = int(index.max()) + 1 if dim_size is None else dim_size
    totals = sum_aggregate(inputs, index, n_groups)
    group_sizes = torch.bincount(index, minlength=n_groups).float().unsqueeze(1)
    # Floor at 1 to prevent division by zero.
    group_sizes = group_sizes.clamp_min(1)
    return totals / group_sizes


def max_aggregate(inputs, index, dim_size=None):
    """
    Max aggregator using torch.scatter_reduce.

    Args:
        inputs: [E, D] tensor (E edges, D embedding dim).
        index: [E] integer tensor giving the node each edge belongs to.
        dim_size: Number of output rows; defaults to ``index.max() + 1``
            (0 when ``index`` is empty).

    Returns:
        [dim_size, D] tensor with the element-wise maximum over each node's
        edges, using the dtype and device of ``inputs``.  Nodes that receive
        no edge are 0 (previously ``-inf``, which turned into inf/NaN as soon
        as the result was fed to a linear layer).
    """
    if dim_size is None:
        # index.max() raises on an empty tensor, so handle that case here.
        dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
    # Start from zeros and exclude them from the reduction
    # (include_self=False): rows that are scattered into hold the true
    # maximum of their inputs, even when all inputs are negative, while
    # untouched rows keep the finite filler value 0.
    out = inputs.new_zeros((dim_size, inputs.size(1)))
    out = out.scatter_reduce(0, index.unsqueeze(-1).expand_as(inputs), inputs,
                             reduce='amax', include_self=False)
    return out

def aggregate_var(inputs, index, dim_size=None):
    """
    Per-node variance of the incoming edge features.

    Uses ``Var[x] = E[x^2] - E[x]^2`` with both expectations taken by
    ``mean_aggregate``; the difference is passed through ReLU so that
    round-off cannot produce a negative variance.

    Args:
        inputs (torch.Tensor): Input features [E, D], where E is the number of edges.
        index (torch.Tensor): Index tensor [E], indicating the target node for each edge.
        dim_size (int, optional): Total number of nodes. Defaults to
            ``index.max() + 1``.

    Returns:
        torch.Tensor: Variance aggregated features [N, D], where N is the number of nodes.
    """
    n_nodes = int(index.max()) + 1 if dim_size is None else dim_size

    second_moment = mean_aggregate(inputs.pow(2), index, n_nodes)  # E[x^2], [N, D]
    first_moment = mean_aggregate(inputs, index, n_nodes)          # E[x],   [N, D]

    return F.relu(second_moment - first_moment.pow(2))

def aggregate_std(inputs, index, dim_size=None):
    """
    Per-node standard deviation of the incoming edge features.

    Computed as ``sqrt(aggregate_var(...) + 1e-8)``; the small constant
    keeps the argument of the square root strictly positive.

    Args:
        inputs (torch.Tensor): Input features [E, D], where E is the number of edges.
        index (torch.Tensor): Index tensor [E], indicating the target node for each edge.
        dim_size (int, optional): Total number of nodes. Defaults to None.

    Returns:
        torch.Tensor: Standard deviation aggregated features [N, D], where N is the number of nodes.
    """
    variance = aggregate_var(inputs, index, dim_size)  # [N, D]
    return torch.sqrt(variance + 1e-8)



##############################################################################
# MultiAggregatorBipartiteConvManual with PreNormLayers
##############################################################################
class MultiAggregatorBipartiteConvManual_PNA(MessagePassing):
    """Bipartite graph convolution with PNA-style multi-aggregation.

    Every edge produces one message from its two endpoint embeddings and its
    scalar edge feature.  Messages are reduced per receiving node with four
    aggregators (mean, sum, max, std); each result is kept as is (identity),
    multiplied by ``log(deg + 1) / avg_d`` (amplification) and multiplied by
    ``avg_d / log(deg + 1)`` (attenuation).  The resulting ``12 * emb_size``
    features are projected back to ``emb_size`` by ``out_mlp``, added to the
    plain sum aggregation, and finally combined with the receiving nodes'
    own embeddings in ``forward``.

    Args:
        emb_size: Width of the node embeddings and of the layer output.
        avg_d: Normalisation constant used by the degree scalers.
    """

    def __init__(self, emb_size=64, avg_d=1.0):
        super().__init__(aggr=None)  # aggregation is done manually in `aggregate`
        self.emb_size = emb_size
        self.avg_d = avg_d

        # Layers are built through `torch.nn` because no visible cell imports
        # a bare `nn` name; `torch` itself is imported.
        # Linear transforms for node/edge embeddings
        self.linear_left = torch.nn.Linear(emb_size, emb_size, bias=False)
        self.linear_right = torch.nn.Linear(emb_size, emb_size, bias=False)
        self.linear_edge = torch.nn.Linear(1, emb_size, bias=False)

        # Per-message normalisation + transform
        self.feature_module_final = torch.nn.Sequential(
            PreNormLayer(1, shift=False),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size)
        )

        # Normalisation of the aggregated messages
        self.post_conv_module = torch.nn.Sequential(
            PreNormLayer(1, shift=False)
        )

        # Projects the 4 aggregators x 3 scalers back to emb_size
        self.out_mlp = torch.nn.Sequential(
            torch.nn.Linear(12 * emb_size, 3 * emb_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(3 * emb_size, emb_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(emb_size, emb_size)
        )

        # Combines the aggregated messages with the receiving nodes' features
        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(2 * emb_size, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size)
        )

    def forward(self, left_feats, edge_index, edge_feats, right_feats):
        """
        left_feats : [N_left, emb_size]
        right_feats: [N_right, emb_size]
        edge_feats : [E, 1]
        edge_index : [2, E]

        Returns a [N_right, emb_size] tensor.
        """
        # Perform message passing
        aggr_out = self.propagate(
            edge_index=edge_index,
            size=(left_feats.size(0), right_feats.size(0)),
            node_features=(left_feats, right_feats),
            edge_features=edge_feats
        )

        # NOTE(review): `aggregate` already applies `post_conv_module`, so the
        # same PreNormLayer is applied a second time here.  Kept as is to
        # preserve the behaviour of existing checkpoints -- confirm whether
        # the double normalisation is intended.
        combined = torch.cat([self.post_conv_module(aggr_out), right_feats], dim=-1)

        # Pass through output_module
        return self.output_module(combined)

    def message(self, node_features_i, node_features_j, edge_features):
        """
        Build one message per edge.

        node_features_i, node_features_j: [E, emb_size]
        edge_features: [E, 1]
        """
        fi = self.linear_left(node_features_i)
        fj = self.linear_right(node_features_j)
        fe = self.linear_edge(edge_features)

        # Sum the transformed features, then normalise and transform per message
        return self.feature_module_final(fi + fj + fe)

    def aggregate(self, inputs, index, dim_size=None):
        """
        Reduce the edge messages per receiving node.

        inputs: [E, emb_size] messages
        index : [E] receiving node of every message
        Returns a [N_nodes, emb_size] tensor.
        """
        if dim_size is None:
            dim_size = int(index.max()) + 1

        # Number of messages per node, floored at 1 so that log(deg + 1) is
        # never zero in the attenuation scaler.
        degrees = torch.bincount(index, minlength=dim_size).float().unsqueeze(1)  # [N_nodes, 1]
        degrees = degrees.clamp_min(1.0)

        # Aggregators, in the order expected by `out_mlp`: mean, sum, max, std.
        sum_agg = sum_aggregate(inputs, index, dim_size)
        aggregated = [
            mean_aggregate(inputs, index, dim_size),
            sum_agg,
            max_aggregate(inputs, index, dim_size),
            aggregate_std(inputs, index, dim_size),
        ]  # each [N_nodes, emb_size]

        # Degree scalers, computed once instead of once per aggregator.
        log_degree = torch.log(degrees + 1)
        amplification = log_degree / self.avg_d
        attenuation = self.avg_d / log_degree

        # Layout: identity block, amplification block, attenuation block.
        cat_agg_out = torch.cat(
            aggregated
            + [amplification * agg for agg in aggregated]
            + [attenuation * agg for agg in aggregated],
            dim=-1,
        )  # [N_nodes, 12 * emb_size]

        out = self.out_mlp(cat_agg_out)  # [N_nodes, emb_size]

        # Residual-style connection with the plain sum aggregation
        mid_output = out + sum_agg  # [N_nodes, emb_size]

        # Apply post_conv_module PreNormLayer
        return self.post_conv_module(mid_output)

    def update(self, aggr_out):
        """Return the aggregated messages unchanged."""
        return aggr_out

#############################################################################
## BaseModel
#############################################################################
class BaseModel(torch.nn.Module):
    """
    Base class that adds the PreNormLayer pre-training protocol to a model.
    """

    def pre_train_init(self):
        """Put every PreNormLayer of the model into statistics-gathering mode."""
        for layer in self.modules():
            if not isinstance(layer, PreNormLayer):
                continue
            layer.start_updates()

    def pre_train_next(self):
        """Finalise the first PreNormLayer that has received updates.

        Returns the finalised layer, or ``None`` when no layer is both
        waiting for updates and has received some.
        """
        for layer in self.modules():
            if not isinstance(layer, PreNormLayer):
                continue
            if layer.waiting_updates and layer.received_updates:
                layer.stop_updates()
                return layer
        return None

    def pre_train(self, *args, **kwargs):
        """Run one gradient-free forward pass for pre-training.

        Returns ``True`` when the pass was interrupted by a
        ``PreNormException`` and ``False`` when it ran to completion.
        """
        try:
            with torch.no_grad():
                self.forward(*args, **kwargs)
        except PreNormException:
            return True
        return False


class GNNPolicy(BaseModel):
    """Branching policy: bipartite multi-aggregator GNN plus feature attention.

    The forward pass embeds constraints, variables and edges, runs one
    variable -> constraint convolution followed by one constraint -> variable
    convolution, applies a feature-wise self-attention block to the variable
    embeddings and scores every variable with a small MLP.

    Args:
        cons_nfeats: Number of raw constraint features.
        var_nfeats: Number of raw variable features.
        edge_nfeats: Number of raw edge features.
        emb_size: Embedding width used throughout the network.
        avg_d_left: ``avg_d`` of the variable -> constraint convolution.
        avg_d_right: ``avg_d`` of the constraint -> variable convolution.
    """

    def __init__(self, cons_nfeats=5, var_nfeats=19, edge_nfeats=1,
                 emb_size=64, avg_d_left=1.0, avg_d_right=1.0):
        super().__init__()
        self.emb_size = emb_size
        self.avg_d_left = avg_d_left
        self.avg_d_right = avg_d_right

        # Layers are built through `torch.nn` because no visible cell imports
        # a bare `nn` name; `torch` itself is imported.
        # Constraint embedding
        self.cons_embedding = torch.nn.Sequential(
            PreNormLayer(cons_nfeats),
            torch.nn.Linear(cons_nfeats, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
        )

        # Variable embedding
        self.var_embedding = torch.nn.Sequential(
            PreNormLayer(var_nfeats),
            torch.nn.Linear(var_nfeats, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
        )

        # Edge embedding (normalisation only)
        self.edge_embedding = torch.nn.Sequential(
            PreNormLayer(edge_nfeats)
        )

        self.v_to_c_layer = MultiAggregatorBipartiteConvManual_PNA(
            emb_size=emb_size,
            avg_d=self.avg_d_left
        )
        self.c_to_v_layer = MultiAggregatorBipartiteConvManual_PNA(
            emb_size=emb_size,
            avg_d=self.avg_d_right
        )

        # Self-attention layers (Q, K, V)
        self.Q = torch.nn.Linear(emb_size, emb_size, bias=True)
        self.K = torch.nn.Linear(emb_size, emb_size, bias=True)
        self.V = torch.nn.Linear(emb_size, emb_size, bias=True)

        # Output head: one score per variable
        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, 1, bias=False),
        )

    def forward(self, constraint_feats, edge_index, edge_feats, variable_feats):
        """
        constraint_feats : [N_constraints, cons_nfeats]
        variable_feats   : [N_vars, var_nfeats]
        edge_feats       : [E, edge_nfeats]
        edge_index       : [2, E] => [constraint_idx, variable_idx]

        Returns a [N_vars] tensor with one score per variable.
        """
        # 1) Initial embeddings
        c_emb = self.cons_embedding(constraint_feats)  # [N_constraints, emb_size]
        v_emb = self.var_embedding(variable_feats)      # [N_vars, emb_size]
        e_emb = self.edge_embedding(edge_feats)         # [E, 1]

        # 2) Reverse edge index for var->cons message passing
        rev_edge_index = torch.stack([edge_index[1], edge_index[0]], dim=0)

        # 3) Single bipartite pass
        #    (a) variable -> constraint
        c_new = self.v_to_c_layer(
            left_feats=v_emb,
            edge_index=rev_edge_index,
            edge_feats=e_emb,
            right_feats=c_emb
        )

        #    (b) constraint -> variable
        # Uses the updated constraint embeddings from (a).  The previous code
        # passed `c_emb` here, so the result of (a) was discarded and
        # `v_to_c_layer` never influenced the output.
        # NOTE: checkpoints trained before this fix have untrained
        # `v_to_c_layer` weights and must be retrained.
        variable_features = self.c_to_v_layer(
            left_feats=c_new,
            edge_index=edge_index,
            edge_feats=e_emb,
            right_feats=v_emb
        )

        # 4) Feature-wise self-attention block
        Q = self.Q(variable_features)    # [num_vars, emb_size]
        K = self.K(variable_features)    # [num_vars, emb_size]
        V = self.V(variable_features)    # [num_vars, emb_size]

        # Attention over feature channels; scaled by sqrt(emb_size), which is
        # the original hard-coded 8.0 for the default emb_size of 64.
        attention_scores = torch.matmul(Q.transpose(0, 1), K) / (self.emb_size ** 0.5)  # [emb_size, emb_size]
        attention_weights = F.softmax(attention_scores, dim=-1)  # [emb_size, emb_size]

        # Mix the value channels with the attention weights
        attended_features = torch.matmul(V, attention_weights)  # [num_vars, emb_size]
        attended_features = F.relu(attended_features)

        # 5) Score each variable node
        output = self.output_module(attended_features).squeeze(-1)  # [num_vars]

        return output

# Generate Dataset

In [None]:
import os
import glob
import gzip
import argparse
import pickle
import queue
import shutil
import threading
import numpy as np
import ecole
from collections import namedtuple


In [None]:
class ExploreThenStrongBranch:
    """Observation function that randomly returns expert or pseudocost scores.

    With probability ``expert_probability`` the strong-branching scores are
    extracted, otherwise the pseudocosts are.  ``extract`` returns the
    scores together with a flag telling which of the two was used.
    """

    def __init__(self, expert_probability):
        self.expert_probability = expert_probability
        self.pseudocosts_function = ecole.observation.Pseudocosts()
        self.strong_branching_function = ecole.observation.StrongBranchingScores()

    def before_reset(self, model):
        """Forward the reset notification to both wrapped functions."""
        for score_function in (self.pseudocosts_function, self.strong_branching_function):
            score_function.before_reset(model)

    def extract(self, model, done):
        """Return ``(scores, scores_are_expert)`` for the current state."""
        p_expert = self.expert_probability
        use_expert = bool(np.random.choice(np.arange(2), p=[1-p_expert, p_expert]))
        if use_expert:
            score_function = self.strong_branching_function
        else:
            score_function = self.pseudocosts_function
        return (score_function.extract(model, done), use_expert)


def send_orders(orders_queue, instances, seed, query_expert_prob, time_limit, out_dir, stop_flag):
    """
    Dispatcher loop: keep pushing sampling orders until ``stop_flag`` is set.

    Each order is the list ``[episode, instance, seed, query_expert_prob,
    time_limit, out_dir]``.  Throttling relies on ``orders_queue.put``
    blocking when the queue is full.

    Parameters
    ----------
    orders_queue : queue.Queue
        Queue to which to send orders.
    instances : list
        Instance file names from which to sample episodes.
    seed : int
        Random seed for reproducibility.
    query_expert_prob : float in [0, 1]
        Probability of running the expert strategy and collecting samples.
    time_limit : float in [0, 1e+20]
        Maximum running time for an episode, in seconds.
    out_dir: str
        Output directory in which to write samples.
    stop_flag: threading.Event
        A flag to tell the thread to stop.
    """
    rng = np.random.RandomState(seed)

    next_episode = 0
    while not stop_flag.is_set():
        # Draw the instance first, then the per-episode seed.
        chosen_instance = rng.choice(instances)
        episode_seed = rng.randint(2**32)
        order = [next_episode, chosen_instance, episode_seed,
                 query_expert_prob, time_limit, out_dir]
        orders_queue.put(order)
        next_episode += 1


def make_samples(in_queue, out_queue, stop_flag):
    """
    Worker loop: fetch an order, run one branching episode and record samples.

    For every order the worker sends a 'start' message, then one 'sample'
    message per sample file written, then a 'done' message to ``out_queue``.
    Samples are only recorded at nodes where the expert scores were used and
    while ``stop_flag`` is not set.

    Parameters
    ----------
    in_queue : queue.Queue
        Input queue from which orders are received. Each order unpacks into
        (episode, instance, seed, query_expert_prob, time_limit, out_dir).
    out_queue : queue.Queue
        Output queue in which to send samples.
    stop_flag: threading.Event
        A flag to tell the thread to stop.
    """
    # Never reset between episodes: it numbers the sample files and is the
    # count printed in the "done" log line, so that count is cumulative.
    sample_counter = 0
    while not stop_flag.is_set():
        # Blocks until an order is available.
        episode, instance, seed, query_expert_prob, time_limit, out_dir = in_queue.get()

        # SCIP parameters handed to the environment.
        # NOTE(review): presumably these disable cuts and restarts and use
        # wall-clock timing -- confirm against the SCIP parameter docs.
        scip_parameters = {'separating/maxrounds': 0, 'presolving/maxrestarts': 0,
                           'limits/time': time_limit, 'timing/clocktype': 2}
        # Two observations per node: the (possibly expert) scores and the
        # bipartite graph representation of the node.
        observation_function = { "scores": ExploreThenStrongBranch(expert_probability=query_expert_prob),
                                 "node_observation": ecole.observation.NodeBipartite() }
        env = ecole.environment.Branching(observation_function=observation_function,
                                          scip_params=scip_parameters, pseudo_candidates=True)

        print(f"[w {threading.current_thread().name}] episode {episode}, seed {seed}, "
              f"processing instance '{instance}'...\n", end='')
        out_queue.put({
            'type': 'start',
            'episode': episode,
            'instance': instance,
            'seed': seed,
        })

        env.seed(seed)
        observation, action_set, _, done, _ = env.reset(instance)
        while not done:
            scores, scores_are_expert = observation["scores"]
            node_observation = observation["node_observation"]
            # Flatten the observation into the tuple layout stored on disk:
            # (row features, (edge indices, edge values), variable features).
            node_observation = (node_observation.row_features,
                                (node_observation.edge_features.indices,
                                 node_observation.edge_features.values),
                                node_observation.variable_features)

            # Branch on the candidate with the highest score.
            action = action_set[scores[action_set].argmax()]

            # Only expert decisions are stored, and nothing is written once
            # the stop flag has been raised.
            if scores_are_expert and not stop_flag.is_set():
                data = [node_observation, action, action_set, scores]
                filename = f'{out_dir}/sample_{episode}_{sample_counter}.pkl'

                with gzip.open(filename, 'wb') as f:
                    pickle.dump({
                        'episode': episode,
                        'instance': instance,
                        'seed': seed,
                        'data': data,
                        }, f)
                out_queue.put({
                    'type': 'sample',
                    'episode': episode,
                    'instance': instance,
                    'seed': seed,
                    'filename': filename,
                })
                sample_counter += 1

            try:
                observation, action_set, _, done, _ = env.step(action)
            except Exception as e:
                # Any failure while stepping ends the episode; the error is
                # appended to error_log.txt in the current working directory.
                done = True
                with open("error_log.txt","a") as f:
                    f.write(f"Error occurred solving {instance} with seed {seed}\n")
                    f.write(f"{e}\n")

        print(f"[w {threading.current_thread().name}] episode {episode} done, {sample_counter} samples\n", end='')
        out_queue.put({
            'type': 'done',
            'episode': episode,
            'instance': instance,
            'seed': seed,
        })


import os
import queue
import threading
import shutil
import re

# def get_last_sample_index(out_dir):
#     """
#     Scans the output directory for existing samples and finds the highest index.
#     """
#     sample_files = [f for f in os.listdir(out_dir) if re.match(r'sample_\d+\.pkl', f)]
#     if sample_files:
#         # Extract numbers from the filenames and find the maximum
#         indices = [int(re.findall(r'\d+', f)[0]) for f in sample_files]
#         return max(indices)
#     else:
#         return 0

def collect_samples(instances, out_dir, rng, n_samples, n_jobs,
                    query_expert_prob, time_limit, last_index):
    """
    Runs branch-and-bound episodes on the given set of instances, and randomly
    collects (state, action) pairs from the 'vanilla-fullstrong' expert
    brancher.

    A dispatcher thread (`send_orders`) feeds a bounded orders queue, `n_jobs`
    worker threads (`make_samples`) consume it and report back on an answers
    queue, and this function moves the reported sample files into `out_dir`
    in episode order until `n_samples` files have been written.

    Parameters
    ----------
    instances : list
        Instance files from which to collect samples.
    out_dir : str
        Directory in which to write samples.
    rng : numpy.random.RandomState
        A random number generator for reproducibility.
    n_samples : int
        Number of samples to collect.
    n_jobs : int
        Number of jobs for parallel sampling.
    query_expert_prob : float in [0, 1]
        Probability of using the expert policy and recording a (state, action)
        pair.
    time_limit : float in [0, 1e+20]
        Maximum running time for an episode, in seconds.
    last_index : int
        Offset applied to output file numbering: the k-th sample written
        (k starting at 1) is stored as `sample_{last_index + k}.pkl`.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Output numbering continues after `last_index` (supplied by the caller).
    print(f"Starting from sample index {last_index + 1}")

    # Queues shared with the dispatcher and the workers. The orders queue is
    # bounded, so the dispatcher blocks once 2*n_jobs orders are pending.
    orders_queue = queue.Queue(maxsize=2*n_jobs)
    answers_queue = queue.SimpleQueue()

    # Directory handed to the dispatcher for intermediate sample files;
    # removed again at the end of this function.
    tmp_samples_dir = f'{out_dir}/tmp'
    os.makedirs(tmp_samples_dir, exist_ok=True)

    # start dispatcher
    dispatcher_stop_flag = threading.Event()
    dispatcher = threading.Thread(
            target=send_orders,
            args=(orders_queue, instances, rng.randint(2**32), query_expert_prob,
                  time_limit, tmp_samples_dir, dispatcher_stop_flag),
            daemon=True)
    dispatcher.start()

    # start workers
    workers = []
    workers_stop_flag = threading.Event()
    for i in range(n_jobs):
        p = threading.Thread(
                target=make_samples,
                args=(orders_queue, answers_queue, workers_stop_flag),
                daemon=True)
        workers.append(p)
        p.start()

    # record answers and write samples
    buffer = {}          # episode id -> messages received but not yet handled
    current_episode = 0  # only this episode's samples may be written next
    i = 0                # number of samples written so far
    in_buffer = 0        # 'sample' messages received but not yet written
    # NOTE(review): `get()` has no timeout, so this loop waits indefinitely if
    # the workers stop producing before `n_samples` is reached.
    while i < n_samples:
        sample = answers_queue.get()

        # add received sample to buffer
        # ('start' opens an episode; 'sample' and 'done' are appended to it)
        if sample['type'] == 'start':
            buffer[sample['episode']] = []
        else:
            buffer[sample['episode']].append(sample)
            if sample['type'] == 'sample':
                in_buffer += 1

        # if any, write samples from current episode
        while current_episode in buffer and buffer[current_episode]:
            samples_to_write = buffer[current_episode]
            buffer[current_episode] = []

            for sample in samples_to_write:

                # if no more samples here, move to next episode
                if sample['type'] == 'done':
                    del buffer[current_episode]
                    current_episode += 1

                # else write sample
                else:
                    # The file is moved, not re-encoded, so its contents stay
                    # exactly as the worker wrote them.
                    sample_index = last_index + i + 1
                    os.rename(sample['filename'], f'{out_dir}/sample_{sample_index}.pkl')
                    in_buffer -= 1
                    i += 1
                    print(f"[m {threading.current_thread().name}] {i} / {n_samples} samples written, "
                          f"ep {sample['episode']} ({in_buffer} in buffer).\n", end='')

                    # early stop dispatcher
                    # (written + buffered samples already cover the target)
                    if in_buffer + i >= n_samples and dispatcher.is_alive():
                        dispatcher_stop_flag.set()
                        print(f"[m {threading.current_thread().name}] dispatcher stopped...\n", end='')

                    # as soon as enough samples are collected, stop
                    # (clearing the buffer also ends the enclosing while loop)
                    if i == n_samples:
                        buffer = {}
                        break

    # stop all workers
    workers_stop_flag.set()
    for p in workers:
        p.join(timeout=15)  # Wait up to 15 seconds for each thread to terminate
        if p.is_alive():
            print(f"Thread {p.name} is still running, skipping.")

    print(f"Done collecting samples for {out_dir}")
    shutil.rmtree(tmp_samples_dir, ignore_errors=True)


In [None]:
if __name__ == '__main__':
    # Command-line style configuration. The argument list is passed explicitly
    # to parse_args, so the values below are what this cell actually runs with.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        choices=['Standard_MTSP', 'MinMax_MTSP', 'Bounded_MTSP', 'JSSP'],
    )
    parser.add_argument('-s', '--seed', help='Random generator seed.',
                        type=int, default=0)
    parser.add_argument('-j', '--njobs', help='Number of parallel jobs.',
                        type=int, default=os.cpu_count())
    args = parser.parse_args(['JSSP'])

    print(f"seed {args.seed}")

    train_size = 10000
    valid_size = 4000
    node_record_prob = 0.05
    time_limit = 4800

    # problem name -> (instance-size tag used in folder names,
    #                  whether a test split is globbed for that problem)
    problem_layouts = {
        'setcover': ('500r_1000c_0.05d', True),
        'cauctions': ('100_500', True),
        'indset': ('500_4', True),
        'facilities': ('100_100_5', True),
        'mknapsack': ('100_6', True),
        'Standard_MTSP': ('12_3', False),
        'Bounded_MTSP': ('12_3', False),
        'MinMax_MTSP': ('9_3', False),
        'JSSP': ('6_3', True),
    }
    if args.problem not in problem_layouts:
        raise NotImplementedError

    size_tag, has_test_split = problem_layouts[args.problem]
    instances_root = f'data/instances/{args.problem}'
    instances_train = glob.glob(f'{instances_root}/train_{size_tag}/*.lp')
    instances_valid = glob.glob(f'{instances_root}/valid_{size_tag}/*.lp')
    if has_test_split:
        instances_test = glob.glob(f'{instances_root}/test_{size_tag}/*.lp')
    out_dir = f'data/samples/{args.problem}/{size_tag}'

    print(f"{len(instances_train)} train instances for {train_size} samples")
    print(f"{len(instances_valid)} validation instances for {valid_size} samples")

    os.makedirs(out_dir, exist_ok=True)

    # (instances, output sub-directory, seed offset, number of samples).
    # Two independent training sets are drawn from the same train instances;
    # no test samples are collected here.
    collection_runs = (
        (instances_train, '/train/1', 3, train_size),
        (instances_train, '/train/2', 4, train_size),
        (instances_valid, '/valid', 1, valid_size),
    )
    for run_instances, run_subdir, seed_offset, run_size in collection_runs:
        rng = np.random.RandomState(args.seed + seed_offset)
        collect_samples(run_instances, out_dir + run_subdir, rng, run_size,
                        args.njobs, query_expert_prob=node_record_prob,
                        time_limit=time_limit, last_index = 0)

# Train GNN

In [None]:
import os
import sys
import argparse
import pathlib

In [None]:
args = argparse.Namespace(gpu=0)  # -1 selects the CPU; any other value is a GPU id


def set_device(gpu):
    """
    Pick the torch device for `gpu` and set CUDA_VISIBLE_DEVICES to match.

    Parameters
    ----------
    gpu : int
        GPU id, or -1 to force the CPU.

    Returns
    -------
    torch.device
        "cuda:0" when a GPU was requested and CUDA is available, otherwise
        "cpu". CUDA_VISIBLE_DEVICES is set to the requested id in the first
        case and to the empty string in the second.
    """
    # CUDA availability is only queried when a GPU was actually requested.
    use_gpu = gpu != -1 and torch.cuda.is_available()
    os.environ['CUDA_VISIBLE_DEVICES'] = f'{gpu}' if use_gpu else ''
    return torch.device("cuda:0" if use_gpu else "cpu")


device = set_device(args.gpu)
print(f"Using device: {device}")

In [None]:
# Define the compute_avg_d_separated function
def compute_avg_d_separated(loader, device=device):
    """
    Average log(degree + 1) over each side of the bipartite graphs in `loader`.

    For every batch, node degrees are counted with `torch.bincount` over the
    rows of `edge_index`: row 0 for left nodes (one per row of
    `constraint_features`) and row 1 for right nodes (one per row of
    `variable_features`).

    Parameters
    ----------
    loader : iterable
        Yields batches exposing `to(device)`, `constraint_features`,
        `variable_features` and `edge_index`.
    device : torch.device
        Device each batch is moved to before counting.

    Returns
    -------
    tuple of float
        `(avg_log_left, avg_log_right)`. A side for which no node was seen
        yields 1.0 instead of dividing by zero.
    """
    def _log_degree_total(endpoints, n_nodes):
        # `minlength` keeps nodes without any edge in the count (degree 0).
        degrees = torch.bincount(endpoints, minlength=n_nodes).float()
        return torch.sum(torch.log(degrees + 1)).item()

    left_total, right_total = 0.0, 0.0
    n_left_nodes, n_right_nodes = 0, 0

    for batch in loader:
        batch = batch.to(device)
        batch_left = batch.constraint_features.size(0)
        batch_right = batch.variable_features.size(0)

        left_total += _log_degree_total(batch.edge_index[0], batch_left)
        right_total += _log_degree_total(batch.edge_index[1], batch_right)
        n_left_nodes += batch_left
        n_right_nodes += batch_right

    avg_log_left = left_total / n_left_nodes if n_left_nodes > 0 else 1.0
    avg_log_right = right_total / n_right_nodes if n_right_nodes > 0 else 1.0
    return avg_log_left, avg_log_right

## Pre-Train

In [None]:

import multiprocessing as mp
import torch.nn.utils as nn_utils
print(mp.cpu_count())

def pretrain(policy, pretrain_loader):
    """
    Drive the policy's pre-training procedure over `pretrain_loader`.

    After `policy.pre_train_init()`, the loader is iterated repeatedly. Each
    batch is moved to the module-level `device` and passed to
    `policy.pre_train`; a falsy return value ends the current pass early.
    After each pass `policy.pre_train_next()` is called, and the procedure
    stops as soon as it returns None.

    Parameters
    ----------
    policy : object
        Must provide `pre_train_init()`, `pre_train(constraint_features,
        edge_index, edge_attr, variable_features)` and `pre_train_next()`.
    pretrain_loader : iterable
        Yields batches exposing `to(device)` and the four attributes above.
        It is iterated once per pass, so it must be re-iterable.

    Returns
    -------
    int
        Number of passes after which `pre_train_next()` returned a non-None
        value (presumably the number of pre-trained layers — confirm against
        the policy implementation).
    """
    policy.pre_train_init()
    n_completed = 0
    while True:
        for batch in pretrain_loader:
            # Rebind the result: `.to()` is only guaranteed to *return* the
            # moved batch, not to move it in place (same pattern as `process`).
            batch = batch.to(device)
            if not policy.pre_train(batch.constraint_features, batch.edge_index, batch.edge_attr, batch.variable_features):
                break

        if policy.pre_train_next() is None:
            break
        n_completed += 1
    return n_completed

def process(policy, data_loader, top_k=(1, 3, 5, 10), optimizer=None):
    """
    Run `policy` over every batch of `data_loader` and report averaged metrics.

    When `optimizer` is given, gradients are enabled and one optimizer step is
    taken per batch on `cross_entropy - entropy_bonus * entropy`; otherwise
    the pass is evaluation-only with gradients disabled.

    Relies on the module-level names `device`, `entropy_bonus` and
    `pad_tensor`.

    Parameters
    ----------
    policy : callable
        Called as `policy(constraint_features, edge_index, edge_attr,
        variable_features)`; its output is indexed with `batch.candidates`.
    data_loader : iterable
        Yields batches exposing `to(device)`, the four graph inputs above,
        `candidates`, `nb_candidates`, `candidate_choices`,
        `candidate_scores` and `num_graphs`.
    top_k : sequence of int
        Values of k for which top-k accuracy is reported.
    optimizer : torch.optim.Optimizer or None
        Optimizer stepped after each batch, or None for evaluation.

    Returns
    -------
    tuple
        `(mean_loss, mean_kacc, mean_entropy)`, averaged over graphs
        (weighted by `batch.num_graphs`). `mean_loss` is the cross-entropy
        term only, without the entropy bonus; `mean_kacc` is a numpy array
        with one entry per value in `top_k`.

    Raises
    ------
    ZeroDivisionError
        If `data_loader` yields no graphs.
    """
    mean_loss = 0
    mean_kacc = np.zeros(len(top_k))
    mean_entropy = 0

    n_samples_processed = 0
    with torch.set_grad_enabled(optimizer is not None):
        for batch in data_loader:
            batch = batch.to(device)
            logits = policy(batch.constraint_features, batch.edge_index, batch.edge_attr, batch.variable_features)
            # Keep candidate variables only, arranged as one padded row per graph.
            logits = pad_tensor(logits[batch.candidates], batch.nb_candidates)

            # if torch.isnan(logits).any():
            #     print("NaN detected in logits!")

            cross_entropy_loss = F.cross_entropy(logits, batch.candidate_choices, reduction='mean')
            entropy = (-F.softmax(logits, dim=-1)*F.log_softmax(logits, dim=-1)).sum(-1).mean()
            loss = cross_entropy_loss - entropy_bonus*entropy

            if optimizer is not None:
                optimizer.zero_grad()

                loss.backward()
                # nn_utils.clip_grad_norm_(policy.parameters(), 5.0)

                optimizer.step()

            true_scores = pad_tensor(batch.candidate_scores, batch.nb_candidates)
            true_bestscore = true_scores.max(dim=-1, keepdims=True).values

            kacc = []
            for k in top_k:
                # Fewer candidate slots than k: counted as a hit for every graph.
                if logits.size()[-1] < k:
                    kacc.append(1.0)
                    continue
                # A hit is any of the k highest-logit candidates whose true
                # score equals the best true score (ties count).
                pred_top_k = logits.topk(k).indices
                pred_top_k_true_scores = true_scores.gather(-1, pred_top_k)
                accuracy = (pred_top_k_true_scores == true_bestscore).any(dim=-1).float().mean().item()
                kacc.append(accuracy)
            kacc = np.asarray(kacc)

            # Per-batch means are re-weighted by the number of graphs.
            mean_loss += cross_entropy_loss.item() * batch.num_graphs
            mean_entropy += entropy.item() * batch.num_graphs
            mean_kacc += kacc * batch.num_graphs
            n_samples_processed += batch.num_graphs

    mean_loss /= n_samples_processed
    mean_kacc /= n_samples_processed
    mean_entropy /= n_samples_processed
    return mean_loss, mean_kacc, mean_entropy

if __name__ == "__main__":
    # Pre-training driver. The argument list is passed explicitly to
    # parse_args, so the problem below is what this cell actually runs with.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        # choices=['setcover', 'cauctions', 'facilities', 'indset', 'mknapsack'],

        choices=['Standard_MTSP', 'MinMax_MTSP', 'Bounded_MTSP', 'JSSP'],
    )
    parser.add_argument(
        '-s', '--seed',
        help='Random generator seed.',
        type=int,
        default=0,
    )
    parser.add_argument(
        '-g', '--gpu',
        help='CUDA GPU id (-1 for CPU).',
        type=int,
        default=0,
    )
    args = parser.parse_args(['Standard_MTSP'])

    ### HYPER PARAMETERS ###
    # With max_epochs = 0 the loop below runs epoch 0 only: the pre-training
    # pass followed by a single validation pass.
    max_epochs = 0
    batch_size = 32
    pretrain_batch_size = 128
    valid_batch_size = 128
    lr = 1e-3
    entropy_bonus = 0.0
    top_k = [1, 3, 5, 10]

    problem_folders = {
        'setcover': 'setcover/500r_1000c_0.05d',
        'cauctions': 'cauctions/100_500',
        'facilities': 'facilities/100_100_5',
        'indset': 'indset/500_4',
        'mknapsack': 'mknapsack/100_6',
        'Standard_MTSP': 'Standard_MTSP/12_3',
        'JSSP': 'JSSP/6_3',
        'MinMax_MTSP': 'MinMax_MTSP/9_3',
        'Bounded_MTSP': 'Bounded_MTSP/12_3',
    }

    problem_folder = problem_folders[args.problem]
    running_dir = f"model/{args.problem}/{args.seed}"
    os.makedirs(running_dir, exist_ok=True)

    sys.path.insert(0, os.path.abspath(f'model'))

    rng = np.random.RandomState(args.seed)
    torch.manual_seed(args.seed)

    ### LOG ###
    # The pre-training log is started afresh on every run.
    logfile = os.path.join(running_dir, 'pretrain_log.txt')
    if os.path.exists(logfile):
        os.remove(logfile)

    log(f"max_epochs: {max_epochs}", logfile)
    log(f"batch_size: {batch_size}", logfile)
    log(f"pretrain_batch_size: {pretrain_batch_size}", logfile)
    log(f"valid_batch_size : {valid_batch_size}", logfile)
    log(f"lr: {lr}", logfile)
    log(f"entropy bonus: {entropy_bonus}", logfile)
    log(f"top_k: {top_k}", logfile)
    log(f"problem: {args.problem}", logfile)
    log(f"gpu: {args.gpu}", logfile)
    log(f"seed {args.seed}", logfile)

    base_path = pathlib.Path('data/samples')

    # Training samples are spread over the sub-folders 'train/1' and 'train/2'.
    train_files = []

    folders = ['1', '2']

    for folder in folders:
        folder_path = base_path / problem_folder / 'train' / folder

        files = list(folder_path.glob('sample_*.pkl'))

        train_files.extend([str(file) for file in files])

        print(f"Loaded {len(files)} files from folder '{folder_path}'.")

        if folder == '1':
            print("Waiting for 1 seconds before processing the next folder...")
            time.sleep(1)

    # Every 2nd file is used for pre-training, every 200th for degree statistics.
    pretrain_files = [f for i, f in enumerate(train_files) if i % 2 == 0]
    avg_d_files = [f for i, f in enumerate(train_files) if i % 200 == 0]
    valid_files = [str(file) for file in (base_path / problem_folder / 'valid').glob('sample_*.pkl')]
    pretrain_data = GraphDataset(pretrain_files)
    avg_d_data = GraphDataset(avg_d_files)

    pretrain_loader = torch_geometric.loader.DataLoader(pretrain_data, pretrain_batch_size, shuffle=False, num_workers=mp.cpu_count())
    avg_d_loader = torch_geometric.loader.DataLoader(avg_d_data, batch_size, shuffle=False, num_workers=mp.cpu_count())

    valid_data = GraphDataset(valid_files)
    valid_loader = torch_geometric.loader.DataLoader(valid_data, valid_batch_size, shuffle=False, num_workers=mp.cpu_count())

    ### COMPUTE avg_d ###
    avg_log_left, avg_log_right = compute_avg_d_separated(avg_d_loader, device=device)
    log(f"Computed avg_d: Left = {avg_log_left:.4f}, Right = {avg_log_right:.4f}", logfile)

    policy = GNNPolicy(avg_d_left=avg_log_left, avg_d_right=avg_log_right).to(device)

    # Use the `lr` hyper-parameter that was logged above rather than a
    # separate literal, so the two cannot drift apart.
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
    scheduler = Scheduler(optimizer, mode='min', patience=10, factor=0.2, verbose=True)

    for epoch in range(max_epochs + 1):
        log(f"EPOCH {epoch}...", logfile)
        if epoch == 0:
            n = pretrain(policy, pretrain_loader)
            log(f"PRETRAINED {n} LAYERS", logfile)
        else:
            # Each epoch trains on a fresh draw (with replacement) of a whole
            # number of batches out of roughly 10000 samples.
            epoch_train_files = rng.choice(train_files, int(np.floor(10000/batch_size))*batch_size, replace=True)
            train_data = GraphDataset(epoch_train_files)
            # Same loader class as everywhere else in this cell
            # (`torch_geometric.data.DataLoader` is the deprecated alias).
            train_loader = torch_geometric.loader.DataLoader(train_data, batch_size, shuffle=True, num_workers=mp.cpu_count())
            train_loss, train_kacc, entropy = process(policy, train_loader, top_k, optimizer)
            log(f"TRAIN LOSS: {train_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, train_kacc)]), logfile)

        # VALIDATION
        valid_loss, valid_kacc, entropy = process(policy, valid_loader, top_k, None)
        log(f"VALID LOSS: {valid_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, valid_kacc)]), logfile)

        scheduler.step(valid_loss)
        if scheduler.num_bad_epochs == 0:
            # Checkpoint name follows the selected problem instead of being
            # fixed to one problem's name.
            torch.save(policy.state_dict(), pathlib.Path(running_dir)/f'{args.problem}_Org_pretrain_PARAM.pkl')
            log(f"  best model so far", logfile)
        elif scheduler.num_bad_epochs == 10:
            log(f"  10 epochs without improvement, decreasing learning rate", logfile)
        elif scheduler.num_bad_epochs == 20:
            log(f"  20 epochs without improvement, early stopping", logfile)
            break


## Training Loop

In [None]:
import argparse
import os
import sys
import pathlib
import matplotlib.pyplot as plt
torch.autograd.set_detect_anomaly(True)

def process(policy, data_loader, top_k=(1, 3, 5, 10), optimizer=None):
    """
    Run `policy` over every batch of `data_loader` and report averaged metrics.

    When `optimizer` is given, gradients are enabled and one optimizer step is
    taken per batch on `cross_entropy - entropy_bonus * entropy`; otherwise
    the pass is evaluation-only with gradients disabled.

    Relies on the module-level names `device`, `entropy_bonus` and
    `pad_tensor`.

    Parameters
    ----------
    policy : callable
        Called as `policy(constraint_features, edge_index, edge_attr,
        variable_features)`; its output is indexed with `batch.candidates`.
    data_loader : iterable
        Yields batches exposing `to(device)`, the four graph inputs above,
        `candidates`, `nb_candidates`, `candidate_choices`,
        `candidate_scores` and `num_graphs`.
    top_k : sequence of int
        Values of k for which top-k accuracy is reported.
    optimizer : torch.optim.Optimizer or None
        Optimizer stepped after each batch, or None for evaluation.

    Returns
    -------
    tuple
        `(mean_loss, mean_kacc, mean_entropy)`, averaged over graphs
        (weighted by `batch.num_graphs`). `mean_loss` is the cross-entropy
        term only, without the entropy bonus; `mean_kacc` is a numpy array
        with one entry per value in `top_k`.

    Raises
    ------
    ZeroDivisionError
        If `data_loader` yields no graphs.
    """
    mean_loss = 0
    mean_kacc = np.zeros(len(top_k))
    mean_entropy = 0

    n_samples_processed = 0
    with torch.set_grad_enabled(optimizer is not None):
        for batch in data_loader:
            batch = batch.to(device)
            logits = policy(batch.constraint_features, batch.edge_index, batch.edge_attr, batch.variable_features)
            # Keep candidate variables only, arranged as one padded row per graph.
            logits = pad_tensor(logits[batch.candidates], batch.nb_candidates)

            # if torch.isnan(logits).any():
            #     print("NaN detected in logits!")

            cross_entropy_loss = F.cross_entropy(logits, batch.candidate_choices, reduction='mean')
            entropy = (-F.softmax(logits, dim=-1)*F.log_softmax(logits, dim=-1)).sum(-1).mean()
            loss = cross_entropy_loss - entropy_bonus*entropy

            if optimizer is not None:
                optimizer.zero_grad()

                loss.backward()
                # nn_utils.clip_grad_norm_(policy.parameters(), 5.0)

                optimizer.step()

            true_scores = pad_tensor(batch.candidate_scores, batch.nb_candidates)
            true_bestscore = true_scores.max(dim=-1, keepdims=True).values

            kacc = []
            for k in top_k:
                # Fewer candidate slots than k: counted as a hit for every graph.
                if logits.size()[-1] < k:
                    kacc.append(1.0)
                    continue
                # A hit is any of the k highest-logit candidates whose true
                # score equals the best true score (ties count).
                pred_top_k = logits.topk(k).indices
                pred_top_k_true_scores = true_scores.gather(-1, pred_top_k)
                accuracy = (pred_top_k_true_scores == true_bestscore).any(dim=-1).float().mean().item()
                kacc.append(accuracy)
            kacc = np.asarray(kacc)

            # Per-batch means are re-weighted by the number of graphs.
            mean_loss += cross_entropy_loss.item() * batch.num_graphs
            mean_entropy += entropy.item() * batch.num_graphs
            mean_kacc += kacc * batch.num_graphs
            n_samples_processed += batch.num_graphs

    mean_loss /= n_samples_processed
    mean_kacc /= n_samples_processed
    mean_entropy /= n_samples_processed
    return mean_loss, mean_kacc, mean_entropy


# from utilities import log, pad_tensor, GraphDataset, Scheduler
# from model import GNNPolicy

if __name__ == "__main__":
    # Training driver. The argument list is passed explicitly to parse_args,
    # so the problem below is what this cell actually runs with.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        # choices=['setcover', 'cauctions', 'facilities', 'indset', 'mknapsack'],

        choices=['Standard_MTSP', 'MinMax_MTSP', 'Bounded_MTSP', 'JSSP'],
    )
    parser.add_argument(
        '-s', '--seed',
        help='Random generator seed.',
        type=int,
        default=0,
    )
    parser.add_argument(
        '-g', '--gpu',
        help='CUDA GPU id (-1 for CPU).',
        type=int,
        default=0,
    )
    args = parser.parse_args(['Standard_MTSP'])  # Adjust arguments as needed

    ### HYPER PARAMETERS ###
    max_epochs = 1000
    batch_size = 32
    pretrain_batch_size = 128
    valid_batch_size = 128
    lr = 1e-3
    entropy_bonus = 0.0
    top_k = [1, 3, 5, 10]

    problem_folders = {
        'setcover': 'setcover/500r_1000c_0.05d',
        'cauctions': 'cauctions/100_500',
        'facilities': 'facilities/100_100_5',
        'indset': 'indset/500_4',
        'mknapsack': 'mknapsack/100_6',
        'MinMax_MTSP': 'MinMax_MTSP/9_3',
        'Standard_MTSP': 'Standard_MTSP/12_3',
        'JSSP': 'JSSP/6_3',
        'Bounded_MTSP': 'Bounded_MTSP/12_3',
    }
    problem_folder = problem_folders[args.problem]
    running_dir = f"model/{args.problem}/{args.seed}"
    os.makedirs(running_dir, exist_ok=True)

    sys.path.insert(0, os.path.abspath(f'model'))

    rng = np.random.RandomState(args.seed)
    torch.manual_seed(args.seed)

    ### LOG ###
    # The log name follows the selected problem, like the checkpoint name
    # further down. An existing log is kept; the file is only created when
    # missing.
    logfile = os.path.join(running_dir, f'{args.problem}_PNA_plot_train_log.txt')
    if not os.path.exists(logfile):
        with open(logfile, 'w') as f:
            f.write("")

    log(f"max_epochs: {max_epochs}", logfile)
    log(f"batch_size: {batch_size}", logfile)
    log(f"pretrain_batch_size: {pretrain_batch_size}", logfile)
    log(f"valid_batch_size : {valid_batch_size}", logfile)
    log(f"lr: {lr}", logfile)
    log(f"entropy bonus: {entropy_bonus}", logfile)
    log(f"top_k: {top_k}", logfile)
    log(f"problem: {args.problem}", logfile)
    log(f"gpu: {args.gpu}", logfile)
    log(f"seed {args.seed}", logfile)

    base_path = pathlib.Path('data/samples')

    # Training samples are spread over the sub-folders 'train/1' and 'train/2'.
    train_files = []

    folders = ['1', '2']

    for folder in folders:
        folder_path = base_path / problem_folder / 'train' / folder

        files = list(folder_path.glob('sample_*.pkl'))

        train_files.extend([str(file) for file in files])

        print(f"Loaded {len(files)} files from folder '{folder_path}'.")

        if folder == '1':
            print("Waiting for 1 seconds before processing the next folder...")
            time.sleep(1)

    valid_files = [str(file) for file in (base_path / problem_folder / 'valid').glob('sample_*.pkl')]

    # Every 200th training file is used to estimate the degree statistics.
    avg_d_files = [f for i, f in enumerate(train_files) if i % 200 == 0]
    avg_d_data = GraphDataset(avg_d_files)
    avg_d_loader = torch_geometric.loader.DataLoader(avg_d_data, batch_size, shuffle=False, num_workers=mp.cpu_count())

    valid_data = GraphDataset(valid_files)
    valid_loader = torch_geometric.loader.DataLoader(valid_data, valid_batch_size, shuffle=False, num_workers=mp.cpu_count()-1)

    ## COMPUTE avg_d ###
    avg_log_left, avg_log_right = compute_avg_d_separated(avg_d_loader, device=device)
    log(f"Computed avg_d: Left = {avg_log_left:.4f}, Right = {avg_log_right:.4f}", logfile)

    policy = GNNPolicy(avg_d_left=avg_log_left, avg_d_right=avg_log_right).to(device)

    # checkpoint_path = pathlib.Path(running_dir) / 'JSSP_Org_pretrain_PARAM.pkl'
    # if checkpoint_path.exists():
    #     policy.load_state_dict(torch.load(checkpoint_path, map_location=device))
    #     log(f"Loaded model parameters from {checkpoint_path}", logfile)
    # else:
    #     log(f"No saved model found at {checkpoint_path}. Exiting.", logfile)
    #     sys.exit(1)

    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
    scheduler = Scheduler(optimizer, mode='min', patience=5, factor=0.25)

    starting_epoch = 1

    log(f"Resuming training from epoch {starting_epoch}...", logfile)

    # Prepare to record metrics
    valid_losses = []
    valid_accs = {k: [] for k in top_k}

    for epoch in range(starting_epoch, max_epochs + 1):
        log(f"EPOCH {epoch}...", logfile)

        # Two training passes per epoch, each on a fresh draw (with
        # replacement) of a whole number of batches out of roughly 10000 samples.
        for train_step in range(2):
            epoch_train_files = rng.choice(train_files, int(np.floor(10000/ batch_size)) * batch_size, replace=True)
            epoch_train_data = GraphDataset(epoch_train_files)
            epoch_train_loader = torch_geometric.loader.DataLoader(epoch_train_data, batch_size, shuffle=True, num_workers=mp.cpu_count())

            train_loss, train_kacc, entropy = process(
                policy,
                epoch_train_loader,
                top_k,
                optimizer
            )

            log(f"Train Step [{train_step+1}/2] - TRAIN LOSS: {train_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, train_kacc)]), logfile)

        # VALIDATION
        valid_loss, valid_kacc, entropy = process(
            policy,
            valid_loader,
            top_k,
            optimizer=None
        )
        # record validation metrics
        valid_losses.append(valid_loss)
        for k_i, k in enumerate(top_k):
            valid_accs[k].append(valid_kacc[k_i])

        log(f"VALID LOSS: {valid_loss:0.3f} " +
            "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, valid_kacc)]),
            logfile)

        # The scheduler's bad-epoch counter decides both checkpointing and
        # early stopping.
        scheduler.step(valid_loss)
        if scheduler.num_bad_epochs == 0:
            torch.save(policy.state_dict(),
                       pathlib.Path(running_dir)/f'{args.problem}_best.pkl')
            log(f"  Best model so far", logfile)
        elif scheduler.num_bad_epochs == 10:
            log(f"  Early stopping after 10 epochs without improvement", logfile)
            break

# Evaluate

In [None]:
import os
import sys
import importlib
import argparse
import csv
import time
import pickle
import ecole
import pyscipopt

In [None]:
if __name__ == "__main__":
    # Evaluate the trained GNN branching policy.
    #
    # For every benchmark instance and every (model, seed) pair, the instance
    # is solved through an ecole Branching environment in which the branching
    # variable is the candidate with the highest model logit. One CSV row of
    # SCIP statistics is written per run to results/<problem>_<timestamp>.csv.
    #
    # Relies on names defined in earlier cells: torch, np, GNNPolicy.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        # choices=['setcover', 'cauctions', 'facilities', 'indset'],

        choices=['Standard_MTSP', 'MinMax_MTSP', 'Bounded_MTSP', 'JSSP'],
    )
    parser.add_argument(
        '-g', '--gpu',
        help='CUDA GPU id (-1 for CPU).',
        type=int,
        default=0,
    )
    # The argument list is given explicitly (notebook usage) instead of being
    # read from sys.argv.
    # NOTE(review): 'Standard_MTSP' and 'Bounded_MTSP' are accepted by the
    # parser but have no instance list below, so this selection currently ends
    # in NotImplementedError. Add the matching branch or select another
    # problem; also confirm that the JSSP checkpoint path used in the model
    # loading section is the intended one for the selected problem.
    args = parser.parse_args(['Standard_MTSP'])

    result_file = f"{args.problem}_{time.strftime('%Y%m%d-%H%M%S')}.csv"
    instances = []
    seeds = [1, 2]

    gnn_models = ['pna_gnn']
    time_limit = 200  # seconds, passed to SCIP as 'limits/time'

    # Define instances based on the selected problem.
    # The setcover/cauctions/facilities/indset branches are only reachable if
    # the corresponding names are re-enabled in the parser's `choices`.
    if args.problem == 'setcover':
        instances += [{'type': 'small', 'path': f"data/instances/setcover/transfer_500r_1000c_0.05d/instance_{i+1}.lp"} for i in range(20)]
        instances += [{'type': 'medium', 'path': f"data/instances/setcover/transfer_1000r_1000c_0.05d/instance_{i+1}.lp"} for i in range(20)]
        instances += [{'type': 'big', 'path': f"data/instances/setcover/transfer_2000r_1000c_0.05d/instance_{i+1}.lp"} for i in range(20)]

    elif args.problem == 'cauctions':
        instances += [{'type': 'small', 'path': f"data/instances/cauctions/transfer_100_500/instance_{i+1}.lp"} for i in range(20)]
        instances += [{'type': 'medium', 'path': f"data/instances/cauctions/transfer_200_1000/instance_{i+1}.lp"} for i in range(20)]
        instances += [{'type': 'big', 'path': f"data/instances/cauctions/transfer_300_1500/instance_{i+1}.lp"} for i in range(20)]

    elif args.problem == 'facilities':
        # Small: instances 6..15. Medium: trimmed to the first 5 instances.
        # Big: disabled (range(0) yields no entries).
        instances += [{'type': 'small', 'path': f"data/instances/facilities/transfer_100_100_5/instance_{i+6}.lp"} for i in range(10)]
        instances += [{'type': 'medium', 'path': f"data/instances/facilities/transfer_200_100_5/instance_{i+1}.lp"} for i in range(5)]
        instances += [{'type': 'big', 'path': f"data/instances/facilities/transfer_400_100_5/instance_{i+1}.lp"} for i in range(0)]

    elif args.problem == 'indset':
        instances += [{'type': 'small', 'path': f"data/instances/indset/transfer_500_4/instance_{i+1}.lp"} for i in range(20)]
        instances += [{'type': 'medium', 'path': f"data/instances/indset/transfer_1000_4/instance_{i+1}.lp"} for i in range(20)]
        instances += [{'type': 'big', 'path': f"data/instances/indset/transfer_1500_4/instance_{i+1}.lp"} for i in range(20)]

    elif args.problem == 'MinMax_MTSP':
        instances += [{'type': 'small', 'path': f"data/instances/MinMax_MTSP/train_9_3/instance_{i+1}.lp"} for i in range(25)]

    elif args.problem == 'JSSP':
        instances += [{'type': 'small', 'path': f"data/instances/JSSP/train_6_3/instance_{i+1}.lp"} for i in range(20)]

    else:
        raise NotImplementedError(f"Problem type '{args.problem}' is not implemented.")

    # One policy entry per (model, seed) pair.
    branching_policies = []

    for model in gnn_models:
        for seed in seeds:
            branching_policies.append({
                'type': 'gnn',
                'name': model,
                'seed': seed,
            })

    print(f"problem: {args.problem}")
    print(f"gpu: {args.gpu}")
    print(f"time limit: {time_limit} s")

    ### PYTORCH SETUP ###
    # NOTE(review): torch is imported (and CUDA queried) in earlier cells, so
    # changing CUDA_VISIBLE_DEVICES here may come too late to take effect --
    # confirm if a GPU id other than 0 is ever needed.
    if args.gpu == -1:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        device = 'cpu'
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpu}'
        # With a single visible device, the selected GPU is addressed as index 0.
        device = "cuda:0"

    # Initialize device (falls back to CPU when CUDA is unavailable)
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load each distinct model once and share it between the policies
    # (seeds) that use it.
    loaded_models = {}
    for policy in branching_policies:
        if policy['type'] == 'gnn':
            if policy['name'] not in loaded_models:
                ### MODEL LOADING ###
                if policy['name'] == 'pna_gnn':
                    model = GNNPolicy(avg_d_left=1.35, avg_d_right=1.54).to(device) # to be adjusted according to the degree scaler

                    model_path = "/content/drive/MyDrive/Thesis/model/JSSP/0/JSSP_PNA_train_PARAM.pkl"
                # elif policy['name'] == 'org_gnn':
                #     model = OrgGNNPolicy().to(device)
                #     model_path = "/content/drive/MyDrive/Thesis/model/JSSP/0/JSSP_Org_train_PARAM.pkl"
                # elif policy['name'] == 'att_gnn':
                #     model = AttGNNPolicy().to(device)
                #     model_path = "/content/drive/MyDrive/Thesis/model/JSSP/0/JSSP_Att_pretrain_PARAM.pkl"

                else:
                    raise Exception(f"Unrecognized GNN policy {policy['name']}")

                model.load_state_dict(torch.load(model_path, map_location=device))
                # The model is used for inference only: switch to evaluation
                # mode so that any train-time-only layer behaviour (e.g.
                # dropout, batch-norm statistics updates) is disabled.
                model.eval()
                loaded_models[policy['name']] = model

            policy['model'] = loaded_models[policy['name']]

    print("Running SCIP...")

    fieldnames = [
        'policy',
        'seed',
        'type',
        'instance',
        'nnodes',
        'nlps',
        'stime',
        'gap',
        'status',
        'walltime',
        'proctime',
    ]
    os.makedirs('results', exist_ok=True)
    # SCIP parameter overrides applied to every run.
    scip_parameters = {
        'separating/maxrounds': 0,
        'presolving/maxrestarts': 0,
        'limits/time': time_limit,
        'timing/clocktype': 1,
        'branching/vanillafullstrong/idempotent': True
    }

    with open(f"results/{result_file}", 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for instance in instances:
            print(f"{instance['type']}: {instance['path']}...")

            for policy in branching_policies:
                if policy['type'] != 'gnn':
                    # Only GNN policies are evaluated in this cell.
                    continue

                # Fresh environment per run, seeded for reproducibility.
                env = ecole.environment.Branching(
                    observation_function=ecole.observation.NodeBipartite(),
                    scip_params=scip_parameters
                )
                env.seed(policy['seed'])
                torch.manual_seed(policy['seed'])

                # Timers cover instance loading (reset) and the whole solve.
                walltime = time.perf_counter()
                proctime = time.process_time()

                observation, action_set, _, done, _ = env.reset(instance['path'])
                while not done:
                    with torch.no_grad():
                        # Prepare observation tensors
                        observation_tensor = (
                            torch.from_numpy(observation.row_features.astype(np.float32)).to(device),
                            torch.from_numpy(observation.edge_features.indices.astype(np.int64)).to(device),
                            torch.from_numpy(observation.edge_features.values.astype(np.float32)).view(-1, 1).to(device),
                            torch.from_numpy(observation.variable_features.astype(np.float32)).to(device)
                        )

                        # Get logits from the model
                        logits = policy['model'](*observation_tensor)
                        # Restrict the logits to the branching candidates and
                        # pick the candidate with the highest logit.
                        action = action_set[logits[action_set.astype(np.int64)].argmax()]
                        # Step the environment with the selected action
                        observation, action_set, _, done, _ = env.step(action)

                walltime = time.perf_counter() - walltime
                proctime = time.process_time() - proctime

                # Retrieve SCIP model metrics
                scip_model = env.model.as_pyscipopt()
                stime = scip_model.getSolvingTime()
                nnodes = scip_model.getNNodes()
                nlps = scip_model.getNLPs()
                gap = scip_model.getGap()
                status = scip_model.getStatus()

                # Write results (flushed immediately so partial runs are kept)
                writer.writerow({
                    'policy': f"{policy['type']}:{policy['name']}",
                    'seed': policy['seed'],
                    'type': instance['type'],
                    'instance': instance['path'],
                    'nnodes': nnodes,
                    'nlps': nlps,
                    'stime': stime,
                    'gap': gap,
                    'status': status,
                    'walltime': walltime,
                    'proctime': proctime,
                })
                csvfile.flush()

                print(f"  {policy['type']}:{policy['name']} {policy['seed']} - {nnodes} nodes {nlps} lps {stime:.2f} ({walltime:.2f} wall {proctime:.2f} proc) s. {status}")

                # Check if time limit is exceeded
                if stime > time_limit:
                    print(f"  {policy['type']}:{policy['name']} exceeded time limit ({stime:.2f} > {time_limit}). Skipping to next instance.")
                    break  # Skip all remaining policies (seeds) for this instance



In [None]:
if __name__ == "__main__":
    # Baseline evaluation with SCIP's internal branching rules.
    #
    # For every benchmark instance and every (brancher, seed) pair, the
    # instance is solved through an ecole Configuring environment with the
    # named branching rule given a very high priority. One CSV row of SCIP
    # statistics is written per run to results/<problem>_<timestamp>.csv.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        # choices=['setcover', 'cauctions', 'facilities', 'indset'],

        choices=['Standard_MTSP', 'MinMax_MTSP', 'Bounded_MTSP', 'JSSP'],
    )
    parser.add_argument(
        '-g', '--gpu',
        help='CUDA GPU id (-1 for CPU).',
        type=int,
        default=0,
    )
    # The argument list is given explicitly (notebook usage) instead of being
    # read from sys.argv.
    args = parser.parse_args(['JSSP'])

    result_file = f"{args.problem}_{time.strftime('%Y%m%d-%H%M%S')}.csv"
    instances = []
    seeds = [1,2]
    internal_branchers = ['relpscost']
    gnn_models = [''] # Can be supervised
    time_limit = 300  # seconds, passed to SCIP as 'limits/time'

    # The setcover/cauctions/facilities/indset branches are only reachable if
    # the corresponding names are re-enabled in the parser's `choices`.
    if args.problem == 'setcover':
        instances += [{'type': 'small', 'path': f"data/instances/setcover/transfer_500r_1000c_0.05d/instance_{i+1}.lp"} for i in range(20)]
        instances += [{'type': 'medium', 'path': f"data/instances/setcover/transfer_1000r_1000c_0.05d/instance_{i+1}.lp"} for i in range(20)]
        instances += [{'type': 'big', 'path': f"data/instances/setcover/transfer_2000r_1000c_0.05d/instance_{i+1}.lp"} for i in range(20)]

    elif args.problem == 'cauctions':
        instances += [{'type': 'small', 'path': f"data/instances/cauctions/transfer_100_500/instance_{i+1}.lp"} for i in range(20)]
        instances += [{'type': 'medium', 'path': f"data/instances/cauctions/transfer_200_1000/instance_{i+1}.lp"} for i in range(20)]
        instances += [{'type': 'big', 'path': f"data/instances/cauctions/transfer_300_1500/instance_{i+1}.lp"} for i in range(20)]

    elif args.problem == 'facilities':
        # Small: instances 3..5. Medium and big: disabled (range(0)).
        instances += [{'type': 'small', 'path': f"data/instances/facilities/transfer_100_100_5/instance_{i+3}.lp"} for i in range(3)]
        instances += [{'type': 'medium', 'path': f"data/instances/facilities/transfer_200_100_5/instance_{i+3}.lp"} for i in range(0)]
        instances += [{'type': 'big', 'path': f"data/instances/facilities/transfer_400_100_5/instance_{i+3}.lp"} for i in range(0)]

    elif args.problem == 'indset':
        instances += [{'type': 'small', 'path': f"data/instances/indset/transfer_500_4/instance_{i+1}.lp"} for i in range(20)]
        instances += [{'type': 'medium', 'path': f"data/instances/indset/transfer_1000_4/instance_{i+1}.lp"} for i in range(20)]
        instances += [{'type': 'big', 'path': f"data/instances/indset/transfer_1500_4/instance_{i+1}.lp"} for i in range(20)]

    elif args.problem == 'Bounded_MTSP':
        instances += [{'type': 'small', 'path': f"data/instances/Bounded_MTSP/train_12_3/instance_{i+1}.lp"} for i in range(20)]

    elif args.problem == 'JSSP':
        instances += [{'type': 'small', 'path': f"data/instances/JSSP/train_8_4/instance_{i+1}.lp"} for i in range(20)]

    else:
        raise NotImplementedError(f"Problem type '{args.problem}' is not implemented.")

    branching_policies = []

    # SCIP internal brancher baselines
    for brancher in internal_branchers:
        for seed in seeds:
            branching_policies.append({
                    'type': 'internal',
                    'name': brancher,
                    'seed': seed,
             })
    # GNN models (listed for bookkeeping only; this cell does not run them,
    # see the skip in the evaluation loop below)
    for model in gnn_models:
        for seed in seeds:
            branching_policies.append({
                'type': 'gnn',
                'name': model,
                'seed': seed,
            })

    print(f"problem: {args.problem}")
    print(f"gpu: {args.gpu}")
    print(f"time limit: {time_limit} s")

    # ### PYTORCH SETUP ###
    # if args.gpu == -1:
    #     os.environ['CUDA_VISIBLE_DEVICES'] = ''
    #     device = 'cpu'
    # else:
    #     os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpu}'
    #     device = f"cuda:0"

    #from model.model import GNNPolicy

    # Load and assign PyTorch models to policies (currently disabled: no GNN
    # model is loaded in this cell).
    loaded_models = {}
    loaded_calls = {}
    # for policy in branching_policies:
    #     if policy['type'] == 'gnn':
    #         if policy['name'] not in loaded_models:
    #             ### MODEL LOADING ###
    #             model = GNNPolicy().to(device)
    #             if policy['name'] == 'supervised':
    #                 # model.load_state_dict(torch.load(f"model/{args.problem}/{policy['seed']}/train_params.pkl"))
    #                 model.load_state_dict(torch.load(f"/content/drive/MyDrive/Thesis/model/facilities/0/facilities_multi_train_params.pkl"))
    #             else:
    #                 raise Exception(f"Unrecognized GNN policy {policy['name']}")
    #             loaded_models[policy['name']] = model

    #         policy['model'] = loaded_models[policy['name']]

    print("running SCIP...")

    fieldnames = [
        'policy',
        'seed',
        'type',
        'instance',
        'nnodes',
        'nlps',
        'stime',
        'gap',
        'status',
        'walltime',
        'proctime',
    ]
    os.makedirs('results', exist_ok=True)
    # SCIP parameter overrides applied to every run.
    scip_parameters = {'separating/maxrounds': 0, 'presolving/maxrestarts': 0, 'limits/time': time_limit,
                       'timing/clocktype': 1, 'branching/vanillafullstrong/idempotent': True}

    with open(f"results/{result_file}", 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for instance in instances:
            print(f"{instance['type']}: {instance['path']}...")

            for policy in branching_policies:
                if policy['type'] != 'internal':
                    # Nothing is solved for non-internal (GNN) policies in this
                    # cell. Skip them so that no row is written with the
                    # environment and timings left over from the previous run.
                    continue

                # Run SCIP's own brancher: the named rule gets a very high
                # priority (presumably so SCIP prefers it over other rules).
                env = ecole.environment.Configuring(scip_params={**scip_parameters,
                                                    f"branching/{policy['name']}/priority": 9999999})
                env.seed(policy['seed'])

                # Timers cover instance loading (reset) and the whole solve.
                walltime = time.perf_counter()
                proctime = time.process_time()

                env.reset(instance['path'])
                # Step with an empty parameter dict; the returned tuple is not needed.
                env.step({})

                walltime = time.perf_counter() - walltime
                proctime = time.process_time() - proctime

                # Retrieve SCIP model metrics
                scip_model = env.model.as_pyscipopt()
                stime = scip_model.getSolvingTime()
                nnodes = scip_model.getNNodes()
                nlps = scip_model.getNLPs()
                gap = scip_model.getGap()
                status = scip_model.getStatus()

                # Write results (flushed immediately so partial runs are kept)
                writer.writerow({
                    'policy': f"{policy['type']}:{policy['name']}",
                    'seed': policy['seed'],
                    'type': instance['type'],
                    'instance': instance['path'],
                    'nnodes': nnodes,
                    'nlps': nlps,
                    'stime': stime,
                    'gap': gap,
                    'status': status,
                    'walltime': walltime,
                    'proctime': proctime,
                })
                csvfile.flush()

                print(f"  {policy['type']}:{policy['name']} {policy['seed']} - {nnodes} nodes {nlps} lps {stime:.2f} ({walltime:.2f} wall {proctime:.2f} proc) s. {status}")
