# Cutting planes enhanced by GCN models and Q-learning for branching

In [None]:
!pip install dgl==0.4.3.post2

Collecting dgl
[?25l  Downloading https://files.pythonhosted.org/packages/c5/b4/84e4ebd70ef3985181ef5d2d2a366a45af0e3cd18d249fb212ac03f683cf/dgl-0.4.3.post2-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)
[K     |████████████████████████████████| 3.0MB 4.5MB/s 
Installing collected packages: dgl
Successfully installed dgl-0.4.3.post2


In [None]:
!pip install pulp==2.1

Collecting pulp
[?25l  Downloading https://files.pythonhosted.org/packages/41/34/757c88c320f80ce602199603afe63aed1e0bc11180b9a9fb6018fb2ce7ef/PuLP-2.1-py3-none-any.whl (40.6MB)
[K     |████████████████████████████████| 40.6MB 121kB/s 
Installing collected packages: pulp
Successfully installed pulp-2.1


In [None]:
import numpy as np
import pandas as pd 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch as th

from copy import deepcopy
from datetime import datetime
import random 
from collections import deque
from dgl import DGLGraph
import seaborn as sns
import matplotlib.pyplot as plt 
import networkx as nx
from sklearn.manifold import TSNE
from google.colab import files

DGL backend not selected or invalid.  Assuming PyTorch for now.
Using backend: pytorch


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


  import pandas.util.testing as tm


In [None]:
# project-local helper modules (problem set-up/branching and graph model classes)
import actions_v12 as act
import graph_models_v17 as gm

Valid inequalities version: 9
No hard restrictions for teeth assignments


**Champion GCN and GAT models**

In [None]:
# Models: each one is constructed with the given hyper-parameters, filled from its
# checkpoint file and switched to eval mode.
def _load_frozen(model, checkpoint_path):
    """Load ``checkpoint_path`` into ``model``, put it in eval mode and return it."""
    model.load_state_dict(th.load(checkpoint_path))
    model.eval()
    return model


GCN1 = _load_frozen(gm.GCN5_MN(4, 256, 8, 0.2), "GCN1_JUN3.pth.tar")
GCN2 = _load_frozen(gm.GCN5_MX(4, 256, 8, 0.2), "GCN2_JUN3.pth.tar")
GCN3 = _load_frozen(gm.GCN5_SU(4, 256, 8, 0.2), "GCN3_JUN3.pth.tar")

GAT1 = _load_frozen(gm.GAT3_MN(4, 16, 8, 8, 0.2), "GAT1_JUN4.pth.tar")
GAT2 = _load_frozen(gm.GAT3_MX(4, 16, 8, 8, 0.2), "GAT2_JUN4.pth.tar")
GAT3 = _load_frozen(gm.GAT3_SU(4, 16, 8, 8, 0.2), "GAT3_JUN4.pth.tar")

# Decision thresholds, one per model output index. Output indices 0-5 are the
# inequality families 0-5; indices 6 and 7 are families 7 and 8 (family 6,
# envelope, has no model output).
tholds = [0.449, 0.343, 0.377, 0.274, 0.593, 0.336, 0.261, 0.223]

**Take actions: predict violated inequality families and add the cuts**

In [None]:
def _mean_score(preds, idx):
    """Average the class-``idx`` score read from row 0 of each prediction in ``preds``."""
    # sum() starts from 0, so this is bit-identical to writing (a + b) / 2 by hand.
    return sum(pred[0][idx].item() for pred in preds) / len(preds)


def pick_actions(problem, tholds, model1, model2, model3, verbose):
    """Predict which families of valid inequalities are violated and try to add them.

    The three models score the current solution graph. Every family whose
    (averaged) score reaches its threshold triggers the matching separation
    routine on ``problem``.

    Family numbering (model output index in brackets):
      0 [0]: subtour elimination      1 [1]: blossoms (basic combs)
      2 [2]: advanced comb            3 [3]: clique tree
      4 [4]: blossom + path           5 [5]: bipartition
      6 [-]: envelope (not applicable, no model output)
      7 [6]: crown with 8 subsets     8 [7]: crown with multiple subsets

    Args:
        problem: solver state exposing ``graph`` and the separation methods
            called below. Self-loops are added to ``problem.graph`` in place.
        tholds: 8 thresholds, indexed like the model outputs.
        model1, model2, model3: graph models returning ``(pred, graph_emb, node_emb)``.
            ``pred[0][i]`` is read as the score of output ``i``.
        verbose: when equal to 1, print one line per triggered family.

    Returns:
        (problem, count_total, true_alarm, false_alarm).
        ``count_total`` sums the counters returned by the separation routines;
        the clique-tree, bipartition and crown counters are capped at 1.
        The alarm counters tally triggered families that did / did not find
        something.
    """
    graph1 = problem.graph
    graph1.add_edges_from(zip(graph1.nodes(), graph1.nodes()))
    graph1 = DGLGraph(graph1)

    # Pure inference: scores are only read via .item(), so no autograd graph is needed.
    with th.no_grad():
        pred1, _, _ = model1(graph1)
        pred2, _, _ = model2(graph1)
        pred3, _, _ = model3(graph1)

    cycle_cutoff = 0.99
    teeth_cutoff = 0.1
    teeth_cutoff_path = 1
    connection_check_flag = 1  # set to 1 to avoid adding redundant constraints

    counter0 = counter1 = counter2 = counter3 = counter4 = counter5 = counter7 = counter8 = 0
    true_alarm = 0
    false_alarm = 0

    def report(found, yes_msg, no_msg):
        # Tally one triggered family as a true or false alarm and optionally log it.
        nonlocal true_alarm, false_alarm
        if found:
            true_alarm += 1
            if verbose == 1:
                print(*yes_msg)
        else:
            false_alarm += 1
            if verbose == 1:
                print(no_msg)

    # 0: subtour elimination (GCN1 - GCN3)
    if _mean_score((pred1, pred3), 0) >= tholds[0]:
        counter0 = problem.subtour_elimn()
        report(counter0 > 0,
               ('yes, there were subtours indeed:', counter0),
               'nope, there were NO subtours, wrong prediction')

    # 1: basic blossom inequalities (GCN3)
    if _mean_score((pred3,), 1) >= tholds[1]:
        counter1 = problem.find_multi_blossoms(cycle_cutoff, teeth_cutoff)
        report(counter1 > 0,
               ('yes, there were blossoms indeed:', counter1),
               'nope, there were NO blossoms, wrong prediction')

    # 2: advanced comb inequalities (GCN3)
    if _mean_score((pred3,), 2) >= tholds[2]:
        counter2 = problem.find_adv_combs(cycle_cutoff, teeth_cutoff)
        report(counter2 > 0,
               ('yes, there were advanced combs indeed:', counter2),
               'nope, there were NO advanced combs, wrong prediction')

    # 3: clique-tree inequality (GCN1 - GCN2 - GCN3); counted at most once
    if _mean_score((pred1, pred2, pred3), 3) >= tholds[3]:
        counter3 = problem.find_clique_tree_2(cycle_cutoff, teeth_cutoff)
        if counter3 > 0:
            counter3 = 1
        report(counter3 > 0,
               ('yes, there was a clique-tree indeed',),
               'nope, there was NO clique-tree, wrong prediction')

    # 5: bipartition inequality (GCN1 - GCN2); counted at most once
    if _mean_score((pred1, pred2), 5) >= tholds[5]:
        counter5 = problem.find_bipartition(cycle_cutoff, teeth_cutoff)
        if counter5 > 0:
            counter5 = 1
        report(counter5 > 0,
               ('yes, there was a bipartiton indeed',),
               'nope, there was NO bipartition, wrong prediction')

    # 4: blossom and path inequality (GCN1 - GCN3)
    if _mean_score((pred1, pred3), 4) >= tholds[4]:
        counter4, _ = problem.find_blossom_n_path(cycle_cutoff, teeth_cutoff, teeth_cutoff_path)
        if counter4 == 1:
            # Only the blossom part was found; still tallied as a true alarm.
            report(True, ('hmm, there was only a blossom here',), None)
        else:
            report(counter4 > 1,
                   ('yes, there were blossom and following path indeed',),
                   'nope, there were NO blossom and path, wrong prediction')

    # NOTE(review): unlike the other families, for the two crown families a result of
    # exactly 1 is not tallied as a true alarm and a miss is not tallied as a false
    # alarm. Kept as-is; confirm whether that is intended.

    # 7: crown with 8 subsets (GCN1 - GCN3)
    if _mean_score((pred1, pred3), 6) >= tholds[6]:
        counter7 = problem.find_crown_8(connection_check_flag)
        if counter7 > 1:
            counter7 = 1
            true_alarm += 1
            if verbose == 1:
                print('yes, there was a crown with 8 subsets indeed')

    # 8: crown with N subsets (GCN1 - GCN3)
    if _mean_score((pred1, pred3), 7) >= tholds[7]:
        counter8 = problem.find_crown_more(connection_check_flag)
        if counter8 > 1:
            counter8 = 1
            true_alarm += 1
            if verbose == 1:
                print('yes, there was a crown with N subsets indeed')

    count_total = counter0 + counter1 + counter2 + counter3 + counter4 + counter5 + counter7 + counter8

    return problem, count_total, true_alarm, false_alarm

In [None]:
def plot_graph(problem1):
    """Draw the current LP solution over the problem's node coordinates.

    Edges with ``x_value >= 0.99`` are drawn blue and the remaining (fractional)
    edges red. Nothing is drawn when ``problem1.X_soln`` is empty.
    ``problem1.X_soln`` is not modified.

    Args:
        problem1: solver state exposing ``node_list``, ``coordinates_dict`` and
            ``X_soln`` (with 'origin', 'destination' and 'x_value' columns).
    """
    soln = problem1.X_soln
    if len(soln) == 0:
        return

    graph = nx.Graph()
    graph.add_nodes_from(problem1.node_list)
    # Keep each colour on its edge: nx.draw consumes edge_color in graph.edges()
    # order, which generally differs from the row order of X_soln.
    for origin, destination, x_value in zip(soln['origin'], soln['destination'], soln['x_value']):
        graph.add_edge(origin, destination, capacity=1.0,
                       color='b' if x_value >= 0.99 else 'r')
    colors = [color for _, _, color in graph.edges(data='color')]

    # set coordinates and plot
    nx.set_node_attributes(graph, problem1.coordinates_dict, 'pos')
    nx.draw(graph, nx.get_node_attributes(graph, 'pos'), with_labels=True, edge_color=colors)
    plt.show()

**Qnet agent**

In [None]:
class QTSP:
    """Deep Q-learning agent that scores branching actions.

    Keeps three things:
      * an online Q-network (``q_network``), trained by ``retrain``;
      * a network of the same architecture for bootstrap targets
        (``target_network``), copied from the online one by ``align_target_model``;
      * a bounded replay memory of transitions.

    Args:
        input_dim: length of the state vector fed to the networks.
        output_dim: number of actions (length of the Q-value vector).
        discount_rate: factor applied to the bootstrapped next-state value.
        epochs: number of passes over each sampled minibatch in ``retrain``.
    """

    def __init__(self, input_dim, output_dim, discount_rate, epochs):
        self._input_size = input_dim
        self._action_size = output_dim
        self.discount_rate = discount_rate
        self.epochs = epochs

        # Replay memory of (state, action, reward, next_state, complete_flag);
        # the oldest entries are evicted once 50000 are stored.
        self.memo = deque(maxlen=50000)

        self.q_network = self.build_compile_model()
        self.optimizer = optim.RMSprop(self.q_network.parameters())
        self.target_network = self.build_compile_model()
        self.align_target_model()
        self.target_network.eval()

    def store(self, state, action, reward, next_state, complete_flag):
        """Append one transition to the replay memory."""
        self.memo.append((state, action, reward, next_state, complete_flag))

    def build_compile_model(self):
        """Create a fresh Q-network (``gm.NNN_`` with dropout 0.1)."""
        model = gm.NNN_(self._input_size, self._action_size, 0.1)
        return model

    def align_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def act(self, state, exploration_rate):
        """Epsilon-greedy action: random with probability ``exploration_rate``, else argmax Q."""
        if np.random.rand() <= exploration_rate:
            return random.randint(0, self._action_size - 1)
        # Pure inference: no autograd graph is needed.
        with th.no_grad():
            q_values = self.q_network(state).detach().numpy()
        best_action = np.argmax(q_values)
        print("this action is QMAX:", best_action)
        return best_action

    def retrain(self, batch_size, loss_list):
        """Run ``epochs`` passes of Huber-loss Q-learning over a replay minibatch.

        At most ``batch_size`` transitions are sampled; fewer are sampled when the
        memory holds fewer. The network's train/eval mode is restored afterwards.

        Returns:
            ``loss_list`` with one loss value (0-d numpy array) appended per update.
        """
        # random.sample raises if asked for more items than are stored.
        minibatch = random.sample(self.memo, min(batch_size, len(self.memo)))
        print("training now")
        was_training = self.q_network.training
        self.q_network.train()

        for _ in range(self.epochs):
            for state, action, reward, next_state, complete_flag in minibatch:
                # States may still be attached to the graph of the models that built
                # them. Only the Q-network is optimised, so cut that history here.
                prediction = self.q_network(state.detach())

                # Bellman target: equal to the prediction except at the taken action,
                # and constant w.r.t. the parameters (no gradient through the target).
                target = prediction.detach().clone()
                if complete_flag == 1:
                    target[action] = reward
                else:
                    # next state's value comes from the frozen target network
                    with th.no_grad():
                        next_q = self.target_network(next_state.detach())
                    target[action] = reward + self.discount_rate * next_q.max().item()

                loss = F.smooth_l1_loss(prediction, target)
                loss_list.append(loss.detach().numpy())
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        self.q_network.train(was_training)
        return loss_list

**Functions for branch-check, state definition, valid inequalities and regret check**

In [None]:
def check_branch(problem):
  """Return 1 if the current LP solution has a fractional edge (x_value < 1), else 0."""
  has_fractional = bool((problem.X_soln.x_value < 1).any())
  return 1 if has_fractional else 0

# define state
def _incident_nodes(edges, node):
  """Return the neighbours of ``node`` in an edge table with 'origin'/'destination' columns.

  Destinations of rows whose origin is ``node`` come first, then origins of rows
  whose destination is ``node``.
  """
  return list(pd.concat([edges.loc[edges['origin'] == node, 'destination'],
                         edges.loc[edges['destination'] == node, 'origin']]))


def _first_other(candidates, excluded, default):
  """Return the first item of ``candidates`` that differs from ``excluded``, else ``default``."""
  return next((item for item in candidates if item != excluded), default)


def define_state(problem, model1, model2):
  """Pick the edge to branch on and encode it as the Q-network state.

  Choice of (node, connection, regret):
    * node: among the nodes of maximum in-degree in the DGL version of
      ``problem.graph``, the first that touches a fractional edge (x_value < 1).
      If none does, the origin of the first fractional edge is used.
    * connection: a neighbour of ``node`` over fractional edges.
    * regret: another such neighbour.
  If connection or regret cannot be found (including when there are no
  fractional edges), the fallback is the first row of ``problem.X_soln``:
  ``node``/``connection`` come from that row, and ``regret`` is another
  neighbour of ``node`` in the full solution, or ``connection`` if there is none.

  The state concatenates summed node embeddings in this order:
  model1(node+connection), model2(node+connection),
  model1(node+regret), model2(node+regret).

  Args:
    problem: solver state exposing ``graph`` (self-loops are added to it in place)
      and ``X_soln`` with 'origin', 'destination' and 'x_value' columns.
    model1, model2: graph models returning ``(pred, graph_emb, node_emb)``;
      ``node_emb`` is indexed by node id.

  Returns:
    (state, node, connection, regret), where ``state`` is a 1-D float tensor.

  Raises:
    IndexError: if the fallback is needed and ``problem.X_soln`` is empty.
  """
  graph1 = problem.graph
  graph1.add_edges_from(zip(graph1.nodes(), graph1.nodes()))
  graph1 = DGLGraph(graph1)

  # Node ids with the largest in-degree in the DGL graph.
  degrees = pd.DataFrame(graph1.in_degrees().detach().numpy()).reset_index(drop=False)
  max_degree_nodes = list(degrees[degrees[0] == degrees[0].max()]['index'])

  _, _, n_emb1 = model1(graph1)
  _, _, n_emb2 = model2(graph1)

  soln = problem.X_soln
  fractional = soln[soln.x_value < 1]

  # -1 marks "not found". Initialising here also covers the no-fractional-edge
  # case, which otherwise leaves these names unbound below.
  connection = -1
  regret = -1
  if len(fractional) > 0:
    # Series.append was removed in pandas 2.0; pd.concat is the portable form.
    fractional_nodes = set(pd.concat([fractional['origin'], fractional['destination']]))
    candidates = [item for item in max_degree_nodes if item in fractional_nodes]
    if candidates:
      node = candidates[0]
      neighbours = _incident_nodes(fractional, node)
      if neighbours:
        connection = neighbours[0]
        regret = _first_other(neighbours, connection, -1)
    else:
      node = fractional['origin'].iloc[0]
      connection = fractional['destination'].iloc[0]
      regret = _first_other(_incident_nodes(fractional, node), connection, -1)

  if connection < 0 or regret < 0:
    # Fall back to the first row of the full solution.
    node = soln['origin'].iloc[0]
    connection = soln['destination'].iloc[0]
    regret = _first_other(_incident_nodes(soln, node), connection, connection)

  xd1 = n_emb1[node].view(-1, 1).float() + n_emb1[connection].view(-1, 1).float()
  xd2 = n_emb1[node].view(-1, 1).float() + n_emb1[regret].view(-1, 1).float()

  xd3 = n_emb2[node].view(-1, 1).float() + n_emb2[connection].view(-1, 1).float()
  xd4 = n_emb2[node].view(-1, 1).float() + n_emb2[regret].view(-1, 1).float()

  state = th.cat((xd1, xd3, xd2, xd4), 0)
  state = state.view(-1)

  return state, node, connection, regret

def add_valid_ineq(problem1, constraint_count):
  """Add subtour-elimination cuts, re-solve the LP relaxation and report whether it moved.

  Only subtour elimination constraints are separated here, for simplicity.

  Returns:
    (problem1, change_flag, constraint_count).
    change_flag is 1 when the origin/destination/x_value table differs after
    re-solving, else 0. Only in that case are the cuts reported by
    subtour_elimn added to constraint_count, the graph rebuilt and completeness
    re-checked.
  """
  key_cols = ['origin', 'destination', 'x_value']
  before = problem1.X_soln[key_cols]
  n_cuts = problem1.subtour_elimn()
  problem1.solve_lp_relax()
  solution_moved = not before.equals(problem1.X_soln[key_cols])

  if not solution_moved:
    return problem1, 0, constraint_count

  problem1.graph = problem1.create_graph()
  problem1.check_if_complete()
  return problem1, 1, constraint_count + n_cuts


def check_regret(problem, connection, regret):
  """Return 1 if edge {connection, regret} is integral (x_value == 1) in the current solution, else 0.

  The edge is matched in either orientation.
  """
  soln = problem.X_soln
  integral = soln[soln['x_value'] == 1]
  forward = (integral['origin'] == connection) & (integral['destination'] == regret)
  backward = (integral['origin'] == regret) & (integral['destination'] == connection)
  return int((forward | backward).any())


**Train/Test instances**

**TSP list**

In [None]:
# Held-out test instances as (tsp_id, csv file name) pairs; the solve loop passes
# both values to act.init_prob.
# NOTE(review): the ids do not follow a single pattern from the file names
# (e.g. TestTSP3_0 -> 185 and TestTSP3_10 -> 195, but TestTSP3_14 -> 183) — confirm
# against the data source.
test_tsp_list = [
(38,'TestTSP0_3.csv'),
(734,'TestTSP14_3.csv'),
(839,'TestTSP16_9.csv'),
(844,'TestTSP16_14.csv'),
(183,'TestTSP3_14.csv'),
(445,'TestTSP8_12.csv'),
(791,'TestTSP15_11.csv'),
(346,'TestTSP6_13.csv'),
(495,'TestTSP9_12.csv'),
(944,'TestTSP18_14.csv'),
(90,'TestTSP1_5.csv'),
(285,'TestTSP5_2.csv'),
(148,'TestTSP2_13.csv'),
(136,'TestTSP2_1.csv'),
(546,'TestTSP10_13.csv'),
(841,'TestTSP16_11.csv'),
(41,'TestTSP0_6.csv'),
(438,'TestTSP8_5.csv'),
(433,'TestTSP8_0.csv'),
(883,'TestTSP17_3.csv'),
(42,'TestTSP0_7.csv'),
(442,'TestTSP8_9.csv'),
(94,'TestTSP1_9.csv'),
(44,'TestTSP0_9.csv'),
(397,'TestTSP7_14.csv'),
(43,'TestTSP0_8.csv'),
(142,'TestTSP2_7.csv'),
(287,'TestTSP5_4.csv'),
(788,'TestTSP15_8.csv'),
(98,'TestTSP1_13.csv'),
(46,'TestTSP0_11.csv'),
(780,'TestTSP15_0.csv'),
(87,'TestTSP1_2.csv'),
(689,'TestTSP13_8.csv'),
(283,'TestTSP5_0.csv'),
(632,'TestTSP12_1.csv'),
(292,'TestTSP5_9.csv'),
(742,'TestTSP14_11.csv'),
(288,'TestTSP5_5.csv'),
(91,'TestTSP1_6.csv'),
(96,'TestTSP1_11.csv'),
(681,'TestTSP13_0.csv'),
(297,'TestTSP5_14.csv'),
(888,'TestTSP17_8.csv'),
(793,'TestTSP15_13.csv'),
(685,'TestTSP13_4.csv'),
(89,'TestTSP1_4.csv'),
(185,'TestTSP3_0.csv'),
(145,'TestTSP2_10.csv'),
(439,'TestTSP8_6.csv'),
(785,'TestTSP15_5.csv'),
(694,'TestTSP13_13.csv'),
(489,'TestTSP9_6.csv'),
(294,'TestTSP5_11.csv'),
(545,'TestTSP10_12.csv'),
(97,'TestTSP1_12.csv'),
(836,'TestTSP16_6.csv'),
(497,'TestTSP9_14.csv'),
(144,'TestTSP2_9.csv'),
(49,'TestTSP0_14.csv'),
(641,'TestTSP12_10.csv'),
(540,'TestTSP10_7.csv'),
(149,'TestTSP2_14.csv'),
(88,'TestTSP1_3.csv'),
(47,'TestTSP0_12.csv'),
(486,'TestTSP9_3.csv'),
(392,'TestTSP7_9.csv'),
(880,'TestTSP17_0.csv'),
(289,'TestTSP5_6.csv'),
(99,'TestTSP1_14.csv'),
(387,'TestTSP7_4.csv'),
(892,'TestTSP17_12.csv'),
(835,'TestTSP16_5.csv'),
(93,'TestTSP1_8.csv'),
(881,'TestTSP17_1.csv'),
(447,'TestTSP8_14.csv'),
(39,'TestTSP0_4.csv'),
(284,'TestTSP5_1.csv'),
(147,'TestTSP2_12.csv'),
(890,'TestTSP17_10.csv'),
(394,'TestTSP7_11.csv'),
(635,'TestTSP12_4.csv'),
(488,'TestTSP9_5.csv'),
(891,'TestTSP17_11.csv'),
(781,'TestTSP15_1.csv'),
(542,'TestTSP10_9.csv'),
(195,'TestTSP3_10.csv'),
(286,'TestTSP5_3.csv'),
(547,'TestTSP10_14.csv'),
(842,'TestTSP16_12.csv'),
(636,'TestTSP12_5.csv'),
(40,'TestTSP0_5.csv'),
(48,'TestTSP0_13.csv'),
(541,'TestTSP10_8.csv'),
(234,'TestTSP4_1.csv'),
(631,'TestTSP12_0.csv'),
(534,'TestTSP10_1.csv'),
(645,'TestTSP12_14.csv'),
(838,'TestTSP16_8.csv'),
(692,'TestTSP13_11.csv'),
(792,'TestTSP15_12.csv'),
(336,'TestTSP6_3.csv'),
(634,'TestTSP12_3.csv'),
(684,'TestTSP13_3.csv'),
(494,'TestTSP9_11.csv'),
(396,'TestTSP7_13.csv'),
(831,'TestTSP16_1.csv'),
(693,'TestTSP13_12.csv'),
(395,'TestTSP7_12.csv'),
(485,'TestTSP9_2.csv'),
(490,'TestTSP9_7.csv'),
(492,'TestTSP9_9.csv'),
(943,'TestTSP18_13.csv'),
(388,'TestTSP7_5.csv'),
(390,'TestTSP7_7.csv'),
(942,'TestTSP18_12.csv'),
(389,'TestTSP7_6.csv'),
(391,'TestTSP7_8.csv'),
(384,'TestTSP7_1.csv'),
(687,'TestTSP13_6.csv'),
(335,'TestTSP6_2.csv'),
(535,'TestTSP10_2.csv'),
(637,'TestTSP12_6.csv'),
(837,'TestTSP16_7.csv'),
(794,'TestTSP15_14.csv'),
(932,'TestTSP18_2.csv'),
(639,'TestTSP12_8.csv'),
(940,'TestTSP18_10.csv'),
(939,'TestTSP18_9.csv'),
(533,'TestTSP10_0.csv'),
(386,'TestTSP7_3.csv'),
(935,'TestTSP18_5.csv'),
(782,'TestTSP15_2.csv'),
(245,'TestTSP4_12.csv'),
(236,'TestTSP4_3.csv'),
(536,'TestTSP10_3.csv'),
(936,'TestTSP18_6.csv'),
(247,'TestTSP4_14.csv'),
(537,'TestTSP10_4.csv'),
(938,'TestTSP18_8.csv'),
(241,'TestTSP4_8.csv'),
(383,'TestTSP7_0.csv'),
(238,'TestTSP4_5.csv'),
(233,'TestTSP4_0.csv'),
(931,'TestTSP18_1.csv'),
(235,'TestTSP4_2.csv'),
(588,'TestTSP11_5.csv'),
(240,'TestTSP4_7.csv'),
(242,'TestTSP4_9.csv'),
(244,'TestTSP4_11.csv'),
(591,'TestTSP11_8.csv'),
(237,'TestTSP4_4.csv'),
(385,'TestTSP7_2.csv'),
(583,'TestTSP11_0.csv'),
(393,'TestTSP7_10.csv'),
(246,'TestTSP4_13.csv'),
(587,'TestTSP11_4.csv'),
(592,'TestTSP11_9.csv'),
(586,'TestTSP11_3.csv'),
(584,'TestTSP11_1.csv'),
(239,'TestTSP4_6.csv'),
(590,'TestTSP11_7.csv'),
(243,'TestTSP4_10.csv'),
(783,'TestTSP15_3.csv')
]

In [None]:
n_test_instances = len(test_tsp_list)
print('number of instances:', n_test_instances)

number of instances: 164


# Q-Network and Cutting Planes together

**Define the agent**

In [None]:
action_size = 6
batch_size = 64
# State length = 2 * 128 + 2 * 256 = 768. Presumably two 128-dim pair embeddings
# from the first state model and two 256-dim ones from the second — confirm
# against define_state.
input_size = 2 * 128 + 2 * 256

loss_list = []
cp_qnet_test = pd.DataFrame()  # one row per solved test instance
MODEL_TAG = 'JUN18'
agent = QTSP(input_size, action_size, discount_rate=0.99, epochs=3)

**Load trained agent**

In [None]:
qnet_checkpoint = "GCN_QNET_" + MODEL_TAG + ".pth.tar"
agent.q_network.load_state_dict(th.load(qnet_checkpoint))
# Only the online Q-network is used when solving the test instances below; the
# target-network checkpoint ("GCN_TARGETNET_" + MODEL_TAG + ".pth.tar") is not loaded.
agent.q_network.eval()

NNN_(
  (layer1): Linear(in_features=768, out_features=768, bias=True)
  (drop_layer1): Dropout(p=0.1, inplace=False)
  (layer2): Linear(in_features=768, out_features=384, bias=True)
  (drop_layer2): Dropout(p=0.1, inplace=False)
  (layer3): Linear(in_features=384, out_features=192, bias=True)
  (drop_layer3): Dropout(p=0.1, inplace=False)
  (layer4): Linear(in_features=192, out_features=6, bias=True)
)

# Solve Test instances

In [None]:
# Index of the first test instance to solve. It is 81 in the recorded run
# (presumably resuming an interrupted session); use 0 for a full fresh run.
START_INDEX = 81

for k in range(START_INDEX, len(test_tsp_list)):

  tsp_id, data_name = test_tsp_list[k]
  problem1 = act.init_prob(tsp_id, data_name, 0, 'continuous')
  initial_objective = problem1.objective_val

  constraint_count = 0
  branch_count = 0
  action_sum = 0

  true_alarm_count = 0
  false_alarm_count = 0
  con_limit = 500  # budget on constraints + branches per instance
  change_flag = 1
  break_flag = 0

  start_time = datetime.now()
  print("iteration:", k)
  print("running for tsp:", tsp_id, "start time:", start_time)

  while (constraint_count + branch_count) <= con_limit and problem1.complete_flag == 0:

    if break_flag == 1:
      break  # while loop

    # Branch when the last round of cuts changed nothing, or once a fifth of the budget is used.
    if change_flag == 0 or (constraint_count + branch_count) > con_limit/5:
      branch_flag = check_branch(problem1)

      if branch_flag == 1:
        branch_count = branch_count + 1
        state, node, connection, regret = define_state(problem1, GAT3, GCN1)
        # Pure inference: no autograd graph is needed for the Q-values.
        with th.no_grad():
          q_values = agent.q_network(state).detach().numpy()
        print('q values are:', q_values)
        action = np.argmax(q_values)
        # For actions above 3, fall back to the best of actions 0-3 when the
        # connection-regret edge is already integral (see check_regret).
        if action > 3:
          c_regret = check_regret(problem1, connection, regret)
          if c_regret == 1:
            action = np.argmax(q_values[0:4])

        action_sum = action_sum + action
        print('branching now, action is:', action, node, connection, regret)

        # take the action
        problem1 = act.add_branch(problem1, node, connection, regret, action)
        problem1.solve_lp_relax()
        problem1.graph = problem1.create_graph()
        problem1.check_if_complete()

      elif branch_flag == 0 and change_flag == 0:
        break_flag = 1  # no more rooms to search through

    # Also keep adding valid inequalities
    old_solution = problem1.X_soln[['origin', 'destination', 'x_value']]
    problem1, count_total, true_alarm, false_alarm = pick_actions(problem1, tholds, GCN1, GCN2, GCN3, 0)

    problem1.solve_lp_relax()
    new_solution = problem1.X_soln[['origin', 'destination', 'x_value']]
    change_flag = 1 - int(old_solution.equals(new_solution))
    constraint_count = constraint_count + count_total

    if change_flag == 1:
      true_alarm_count = true_alarm_count + true_alarm
      false_alarm_count = false_alarm_count + false_alarm
      problem1.graph = problem1.create_graph()
      problem1.check_if_complete()
    else:
      # NOTE(review): this adds the cumulative constraint_count, not this round's
      # alarms (true_alarm + false_alarm?). It looks unintended; kept so reported
      # metrics stay comparable — confirm before changing.
      false_alarm_count = false_alarm_count + constraint_count

  end_time = datetime.now()
  problem1.check_if_complete()
  if problem1.complete_flag == 0:
    print('tour is NOT complete, branching needed!')
  else:
    print('tour is complete!')
    print('objective_val', problem1.objective_val)

  obj1 = {'tsp_id': tsp_id,
            'initial_objective': initial_objective,
            'final_objective' : problem1.objective_val,
            'start_time': start_time,
            'end_time': end_time,
            'complete_flag':problem1.complete_flag,
            'n_of_constraint': constraint_count,
            'n_of_branch': branch_count,
            'action_sum':action_sum,
            'true_alarm': true_alarm_count,
            'false_alarm': false_alarm_count,
            'con_limit': con_limit
          }
  # DataFrame.append was removed in pandas 2.0. Concatenating per instance keeps
  # the partial results if the (long) loop is interrupted.
  cp_qnet_test = pd.concat([cp_qnet_test, pd.DataFrame(obj1, index=[0])])

# download the output
cp_qnet_test.to_csv('CPQ_test_perfo.csv')
files.download('CPQ_test_perfo.csv')

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
q values are: [20.194447 15.178723 23.986702 10.857274 19.041786 21.481552]
branching now, action is: 2 41 42 43
q values are: [20.699461 15.36964  23.385155 11.609417 19.29582  22.048775]
branching now, action is: 2 62 69 53
q values are: [19.020584 14.109396 21.496956 12.636678 17.615341 22.176216]
branching now, action is: 5 19 22 35
q values are: [22.344357 15.886289 25.913177 11.297845 20.609398 24.085682]
branching now, action is: 2 99 100 101
q values are: [22.14153  15.736742 24.2262   12.425621 20.247892 24.0858  ]
branching now, action is: 2 101 103 104
q values are: [20.633287 12.004335 20.805576 14.430344 16.032534 22.449455]
branching now, action is: 5 42 28 29
tour is NOT complete, branching needed!
iteration: 91
running for tsp: 40 start time: 2020-06-19 19:48:23.086221
q values are: [22.218452 15.972099 24.990059 11.734425 20.405546 23.750748]
branching now, action is: 2 63 85 15
q values are: [20.393387 1

In [None]:
# Re-export and download the results table, e.g. after the solve loop was interrupted.
results_csv = 'CPQ_test_perfo.csv'
cp_qnet_test.to_csv(results_csv)
files.download(results_csv)