## Beam Search Experiments

This notebook explores preliminary experiments that combine beam search with multi-shot sampling.

In [1]:
import sys
sys.path.append("../../dataset")
sys.path.append("../../model")
sys.path.append("../../")

import pandas as pd
import numpy as np
from sampling import *

from train_bgp import Model as GraphBgpModel
from bgp_semantics import BgpSemantics
from factbase import *
from tqdm import tqdm
import time
import torch
import argparse
from snapshot import ModelSnapshot
import os

class SampleDescriptor:
    """Plain record describing one sample used in the experiments.

    Attributes:
        num_nodes: value passed as ``num_nodes`` (stored unchanged).
        num_networks: value passed as ``num_networks`` (stored unchanged).
        program: value passed as ``program`` (stored unchanged); later cells
            call ``.to_torch_data`` and ``.constants`` on it, so it is
            presumably a ``FactBase`` -- TODO confirm.
    """

    def __init__(self, num_nodes, num_networks, program):
        # Store the three constructor arguments verbatim, in declaration order.
        self.num_nodes, self.num_networks, self.program = (
            num_nodes,
            num_networks,
            program,
        )

In [59]:

# Notebook-wide setup: semantics, feature registry and the pre-trained model.
# All names assigned here are globals that later cells and functions read.

# Inference runs on CPU; the trailing comment keeps the CUDA-if-available variant for reference.
device = torch.device("cpu") # torch.device('cuda' if torch.cuda.is_available() else 'cpu')

s = BgpSemantics()

# Predicate declarations of the BGP semantics; printed for inspection and used to build the fact base.
predicate_declarations = s.decls()
print(predicate_declarations)
prog = FactBase(predicate_declarations)
# `feature` is the registry's lookup callable; below it is called with a feature name and `.idx` is read from the result.
feature = prog.feature_registry.feature

# Placeholder value: `excluded_feature_indices` is overwritten by the checkpoint contents below.
excluded_feature_indices = set([1])
features = prog.feature_registry.get_all_features()

# `model` is replaced below once the checkpoint has been loaded.
model = None
# Used as the default of `without_static_routes` in mask_parameters, so it must be defined before that function.
NO_STATIC_ROUTES = True
# Read by mask_parameters to decide whether the bgp_route columns are masked.
protocol = "bgp"

model_path = "../../trained-model/bgp-64-pred-6layers-model-epoch2800.pt"
# The checkpoint unpacks into (state_dict, HIDDEN_DIM, NUM_EDGE_TYPES, excluded_feature_indices).
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict, HIDDEN_DIM, NUM_EDGE_TYPES, excluded_feature_indices = torch.load(model_path, map_location=device)
model = GraphBgpModel(features, HIDDEN_DIM, NUM_EDGE_TYPES, excluded_feature_indices).to(device)

# Convert the stored state dict before loading (presumably an older GATConv parameter layout -- TODO confirm).
state_dict = convert_old_gat_conv_state_dict(state_dict)
model.load_state_dict(state_dict)
# Attach the feature lookup to the model instance.
model.feature = feature
print("using model at", model_path)

print("model iterations", model.num_iterations)

def mask_parameters(x, decls, with_prob_static_route=True, without_static_routes=NO_STATIC_ROUTES):
    """Build a boolean mask over ``x`` selecting the parameter entries to predict.

    ``x`` is indexed as ``x[:, :, feature_idx]``, i.e. the feature index is the
    last of three dimensions. For every selected feature column, an entry is
    marked ``True`` where the stored value is greater than ``-1``; all other
    entries are ``False``.

    Selected columns:
      * ``predicate_connected_arg2`` (link weight) -- always.
      * ``predicate_bgp_route_arg{2,3,5}`` (LP, AS, MED) -- only when the
        notebook-global ``protocol`` equals ``"bgp"``.

    Args:
        x: feature tensor to derive the mask from.
        decls: unused by this function.
        with_prob_static_route: unused by this function.
        without_static_routes: unused by this function.

    Returns:
        A bool tensor with the same shape as ``x``.
    """
    # The connected-edge weight column is always maskable.
    masked_feature_names = ["predicate_connected_arg2"]

    # bgp_route arguments: gateway, network, LP, AS, OT, MED, IS_EBGP, SPEAKER_ID
    #   arg2 = LP, arg3 = AS, arg4 = OT, arg5 = MED, arg6 = IS_EBGP, arg7 = SPEAKER_ID
    # Only LP, AS and MED are masked.
    if protocol == "bgp":
        masked_feature_names += ["predicate_bgp_route_arg%d" % arg for arg in (2, 3, 5)]

    mask = torch.zeros_like(x)
    for name in masked_feature_names:
        column = feature(name).idx
        # Values <= -1 are left unmasked.
        mask[:, :, column] = x[:, :, column] > -1

    return mask.bool()

def sample_random_prediction(model, prediction_features, batched_data, mask):
    """Fill the masked entries of ``batched_data.x`` with uniformly random values.

    This is a random baseline: it does not query the model.

    Args:
        model: unused; accepted so the signature matches the other samplers.
        prediction_features: iterable of feature objects exposing ``.idx``;
            each of these columns receives random integers in ``[0, 32)``.
        batched_data: object whose ``x`` tensor is indexed as
            ``x[:, :, feature_idx]``.
        mask: boolean tensor broadcastable to ``batched_data.x``; ``True``
            entries are replaced by random values, ``False`` entries keep
            their original value.

    Returns:
        A tensor shaped like ``batched_data.x``.
    """
    x = batched_data.x
    # Size and device come from the argument itself rather than from the
    # notebook-global `data` / `device`, so the function works for any input
    # and does not depend on which cells were executed before.
    num_nodes = x.size(0)

    r = torch.zeros_like(x)
    for f in prediction_features:
        # One random value per node, broadcast across the second dimension.
        r[:, :, f.idx] = torch.randint(0, 32, size=[num_nodes, 1], device=x.device)
    # The IS_EBGP column is binary, so draw it from {0, 1}.
    r[:, :, feature("predicate_bgp_route_arg6").idx] = torch.randint(0, 2, size=[num_nodes, 1], device=x.device)

    # Random values where masked, original values everywhere else.
    return mask * r + mask.logical_not() * x

{'router': router: Constant, 'network': network: Constant, 'external': external: Constant, 'route_reflector': route_reflector: Constant, 'ibgp': ibgp: Constant × Constant, 'ebgp': ebgp: Constant × Constant, 'bgp_route': bgp_route: Constant × Constant × int × int × int × int × int × int, 'connected': connected: Constant × Constant × int, 'fwd': fwd: Constant × Constant × Constant, 'reachable': reachable: Constant × Constant × Constant, 'trafficIsolation': trafficIsolation: Constant × Constant × Constant × Constant}
using model at ../../trained-model/bgp-64-pred-6layers-model-epoch2800.pt
model iterations 4


In [60]:
import os

def get_id(filename):
    """Extract the integer sample id from a dataset file name.

    The id is the text between the first ``"-n"`` and the following ``"."``,
    with an optional ``"-unsatsample..."`` suffix stripped, e.g.
    ``"bgp-n12.logic" -> 12`` and ``"bgp-n7-unsatsample.logic" -> 7``.

    Raises:
        IndexError: if ``filename`` contains no ``"-n"``.
        ValueError: if the extracted text is not an integer.
    """
    after_marker = filename.split("-n", 1)[1]
    stem = after_marker.partition(".")[0]
    # partition returns the whole string as its head when the suffix is absent.
    digits = stem.partition("-unsatsample")[0]
    return int(digits)

# Directory holding the serialized evaluation programs (*.logic).
dataset = "../consistency/dataset-ported/bgp-qlty-reqs-16/"
# Keep only the .logic files, ordered by the numeric id embedded in the file name.
files = sorted(
    (name for name in os.listdir(dataset) if name.endswith(".logic")),
    key=get_id,
)
print(files)
# NOTE(review): torch.load unpickles each file -- only use trusted datasets.
programs = [torch.load(os.path.join(dataset, name)) for name in files]

['bgp-n0.logic', 'bgp-n1.logic', 'bgp-n2.logic', 'bgp-n3.logic', 'bgp-n4.logic', 'bgp-n5.logic', 'bgp-n6.logic', 'bgp-n7.logic', 'bgp-n8.logic', 'bgp-n9.logic', 'bgp-n10.logic', 'bgp-n11.logic', 'bgp-n12.logic', 'bgp-n13.logic', 'bgp-n14.logic', 'bgp-n15.logic', 'bgp-n16.logic', 'bgp-n17.logic', 'bgp-n18.logic', 'bgp-n19.logic', 'bgp-n20.logic', 'bgp-n21.logic', 'bgp-n22.logic', 'bgp-n23.logic']


In [61]:
# Enable IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
from model.beam import beam_search

# Experiment configuration.
num_samples = 1
random = True
num_shots = 4
sampling_mode = "topk"

# Evaluate the first program of the dataset.
descriptor = programs[0]

# Dataset entries stored as raw Data/dict objects are wrapped so that the rest
# of the cell can always go through `descriptor.program`.
if type(descriptor) is Data or type(descriptor) is dict:
    descriptor = SampleDescriptor(0, 0, FactBase.from_data(descriptor))
data, names = descriptor.program.to_torch_data(return_node_names=True)

# Feature columns the beam search is asked to predict.
prediction_features = [
    feature("predicate_connected_arg2"),  # OSPF weights
    # bgp_route: LP x AS x -OT x MED x -IS_EBGP x -SPEAKER_ID
    feature("predicate_bgp_route_arg2"),  # BGP LP
    feature("predicate_bgp_route_arg3"), # BGP AS
    #feature("predicate_bgp_route_arg4"), # BGP ORIGIN_TYPE
    feature("predicate_bgp_route_arg5"), # BGP MED
    #feature("predicate_bgp_route_arg6"), # BGP IS_EBGP
    #feature("predicate_bgp_route_arg7") # SPEAKER_ID
]

# Work on a copy: insert a second dimension of size 1 into x and make the
# edge set bidirectional and reflexive before running the model.
batched_data = data.clone().to(device)
batched_data.x = batched_data.x.unsqueeze(1)
batched_data.edge_index = reflexive(bidirectional(batched_data.edge_index), num_nodes=batched_data.x.size(0))
batched_data.edge_type = reflexive_bidirectional_edge_type(batched_data.edge_type, batched_data.x.size(0))
mask = mask_parameters(batched_data.x, predicate_declarations).to(device)

best_consistency = 0

# Start the timer immediately before the search; `tstart` was previously never
# assigned, so computing `timeelapsed` raised NameError on a fresh kernel.
tstart = time.time()
data.x = beam_search(model, prediction_features, batched_data, mask, iterative=True, 
    number_of_shots=num_shots, inverted=False, mode=sampling_mode, beam_n=128, beam_k=8)[:,0]

timeelapsed = time.time() - tstart
# Turn the predicted tensor back into a fact base and check it against the BGP semantics.
predicted_program = FactBase.from_data(data, decls=predicate_declarations, names=names)
consistency, summary = s.check(predicted_program, return_summary=True)
best_consistency = max(consistency, best_consistency)

def get_value(k):
    """Return ``summary[k]``, or ``1.0`` when ``k`` is not a key of ``summary``.

    Reads the notebook-global ``summary`` produced by the consistency check.
    """
    return summary[k] if k in summary.keys() else 1.0

# Topology size of the evaluated program: routers plus route reflectors, and networks.
num_nodes = len(descriptor.program.constants("router")) + len(descriptor.program.constants("route_reflector"))
num_networks = len(descriptor.program.constants("network"))

# This cell evaluates a single sample, so the sample index is fixed at 0.
# The original print referenced an undefined loop variable `j`, which raised
# NameError on a fresh kernel.
sample_index = 0
print("Consistency %0.2f (best %0.2f) (Nodes %d, Sample %d)" % (consistency, best_consistency, num_nodes, sample_index))