Skip to content

Commit

Permalink
after the flake8 linter
Browse files Browse the repository at this point in the history
  • Loading branch information
gokhankici committed Jun 7, 2019
1 parent 8be729c commit cba7afe
Showing 1 changed file with 63 additions and 57 deletions.
120 changes: 63 additions & 57 deletions scripts/linprog/assumptions.py
Expand Up @@ -2,22 +2,16 @@
# vim: set foldmethod=marker:

import sys
import warnings
import networkx as nx
import collections
import cplex
import subprocess
import time
import itertools
import enum
import json

from flow_capacity import CplexFlowCapSolver
from utils import *
from tests import get_test
from utils import debug, parse_cplex_input, val_to_int

# from networkx.algorithms.components import strongly_connected_components
# import pudb

import pudb

class Variable(collections.namedtuple("Variable", ["node",
"var_index",
Expand All @@ -40,15 +34,17 @@ def __str__(self):
self.var_index,
self.mark_index)

class Assumptions(collections.namedtuple("Assumptions", ["always_eq", "initial_eq"])):

class Assumptions(collections.namedtuple("Assumptions",
["always_eq", "initial_eq"])):
"""
always_eq are the node ids of the variables that need the always equal assumption.
initial_eq is the same thing for initially equal assumption.
always_eq are the node ids of the variables that need the always equal
assumption. initial_eq is the same thing for initially equal assumption.
"""
def json_dump(self):
""" names are the mapping from node ids to variable names """
j = { "always_eq" : list(v.name for v in self.always_eq),
"initial_eq" : list(v.name for v in self.initial_eq) }
j = {"always_eq": list(v.name for v in self.always_eq),
"initial_eq": list(v.name for v in self.initial_eq)}
return json.dumps(j, indent=2)

def print(self, **kwargs):
Expand All @@ -57,6 +53,7 @@ def print(self, **kwargs):
for v in self.initial_eq:
print("// @annot{{sanitize({})}}".format(v.name), **kwargs)


class AssumptionSolver:
def __init__(self, filename):
"""
Expand All @@ -65,7 +62,9 @@ def __init__(self, filename):
parsed = parse_cplex_input(filename)
self.g = parsed["graph"]

cannot_mark_eq = set(parsed["cannot_mark_eq"]) # nodes we DO NOT want to mark as "always_eq"
# nodes we DO NOT want to mark as "always_eq"
cannot_mark_eq = set(parsed["cannot_mark_eq"])

def is_node_markable(n):
return n not in cannot_mark_eq

Expand All @@ -74,15 +73,16 @@ def is_node_markable(n):

# variables used in the linear problem
# mapping from variable identifier to the variable itself
self.variables = { v : Variable(node = v,
name = parsed["names"][v],
var_index = i,
mark_index = i + node_cnt,
is_markable = is_node_markable(v),
is_register = parsed["is_reg"][v])
for i, v in enumerate(self.g.nodes()) }
self.variables = {v: Variable(node=v,
name=parsed["names"][v],
var_index=i,
mark_index=i + node_cnt,
is_markable=is_node_markable(v),
is_register=parsed["is_reg"][v])
for i, v in enumerate(self.g.nodes())}

self.must_eq = set(self.variables[n] for n in parsed["must_eq"]) # nodes we DO want to be "always_eq"
# nodes we DO want to be "always_eq"
self.must_eq = set(self.variables[n] for n in parsed["must_eq"])

# calculate costs of the nodes
# node_costs : Node -> Int
Expand All @@ -97,13 +97,16 @@ def calc_costs(self):
"""
Returns a mapping from node ids to their costs
"""
costs = collections.defaultdict(int)
worklist = collections.deque( n for n in self.g.nodes if len(self.g.pred[n]) == 0 )
done = set(worklist)
costs = collections.defaultdict(int)
worklist = collections.deque(n
for n in self.g.nodes
if len(self.g.pred[n]) == 0)
done = set(worklist)

while worklist:
v = worklist.popleft()
costs[v] = max((costs[u] for u in self.g.predecessors(v)), default=0) + 1
costs[v] = max((costs[u] for u in self.g.predecessors(v)),
default=0) + 1
for u in self.g.successors(v):
if u not in done:
worklist.append(u)
Expand All @@ -121,40 +124,39 @@ def add_always_eq_constraints(self, prob):
b_m is the mark variable o b
"""
for node in self.g.nodes:
var = self.variables[node]
parents = [ self.variables[u] for u in self.g.predecessors(node) ]
c = max(len(parents), 1)
i = var.var_index
indices = [ i ]
coefficients = [ c ]
var = self.variables[node]
parents = [self.variables[u] for u in self.g.predecessors(node)]
c = max(len(parents), 1)
i = var.var_index
indices = [i]
coefficients = [c]

if var.is_markable:
mi = var.mark_index
indices.append(mi)
coefficients.append(-c)
le = cplex.SparsePair(ind = [mi, i], val = [1, -1])
prob.linear_constraints.add(lin_expr = [le], senses = "L", rhs = [0])
le = cplex.SparsePair(ind=[mi, i], val=[1, -1])
prob.linear_constraints.add(lin_expr=[le], senses="L", rhs=[0])

for p in parents:
debug("{} * {} <- {}".format(c,
var.name,
", ".join(v.name for v in parents)))
debug("{} * {} <- {}".
format(c,
var.name,
", ".join(v.name for v in parents)))
indices.append(p.var_index)
coefficients.append(-1)

le = cplex.SparsePair(ind = indices, val = coefficients)
prob.linear_constraints.add(lin_expr = [le], senses = "L", rhs = [0])
le = cplex.SparsePair(ind=indices, val=coefficients)
prob.linear_constraints.add(lin_expr=[le], senses="L", rhs=[0])

def add_must_eq_constraints(self, prob):
"""
Adds the following constraint for each n in must equal set
n = 1
"""
for var in self.must_eq:
prob.linear_constraints.add(lin_expr = [ cplex.SparsePair(ind = [ var.var_index ],
val = [1] ) ],
senses = "E",
rhs = [1])
le = cplex.SparsePair(ind=[var.var_index], val=[1])
prob.linear_constraints.add(lin_expr=[le], senses="E", rhs=[1])

def get_objective_function(self):
"""
Expand All @@ -165,7 +167,7 @@ def get_objective_function(self):
n = var.node
c = self.node_costs[n]
if var.is_markable:
i = var.mark_index
i = var.mark_index
obj[i] = c
return obj

Expand All @@ -181,19 +183,20 @@ def suggest_assumptions(self):
nodes in must_eq to be always_eq as well.
"""
prob = cplex.Cplex()
prob.set_problem_type(prob.problem_type.MILP) # use Mixed Integer Linear Programming solver
prob.set_results_stream(None) # disable results output
prob.set_log_stream(None) # disable logging output
prob.set_problem_type(prob.problem_type.MILP) # use MILP solver
prob.set_results_stream(None) # disable results output
prob.set_log_stream(None) # disable logging output

# objective is to minimize
prob.objective.set_sense(prob.objective.sense.minimize)

t0 = time.perf_counter() # start the stopwatch
t0 = time.perf_counter() # start the stopwatch

# update cost and upper bound
prob.variables.add(obj = self.get_objective_function(),
ub = self.get_upper_bounds(),
types = [ prob.variables.type.integer ] * self.variable_count)
ts = [prob.variables.type.integer] * self.variable_count
prob.variables.add(obj=self.get_objective_function(),
ub=self.get_upper_bounds(),
types=ts)

self.add_always_eq_constraints(prob)

Expand All @@ -204,16 +207,18 @@ def suggest_assumptions(self):
sol = prob.solution

t1 = time.perf_counter()
print("elapsed time: {} ms".format(int((t1-t0) * 1000)), file=sys.stderr)
print("elapsed time: {} ms".format(int((t1-t0) * 1000)),
file=sys.stderr)

assert(sol.get_method() == sol.method.MIP)
prob.write("assumptions.lp") # log the constraints to a file
prob.write("assumptions.lp") # log the constraints to a file

# check if we have found an optimal solution
if sol.get_status() == sol.status.MIP_optimal:
return self.make_result(sol)
else:
print("linprog failed: {}".format(sol.get_status_string()), file=sys.stderr)
print("linprog failed: {}".format(sol.get_status_string()),
file=sys.stderr)
sys.exit(1)

def is_always_eq(self, sol, var):
Expand All @@ -239,7 +244,8 @@ def is_marked(self, sol, var):

def make_result(self, solution):
"""
Convert the MILP solution to a set of assumptions that Iodine can understand
Convert the MILP solution to a set of assumptions that Iodine can
understand
"""
marked, always_eq, initial_eq = set(), set(), set()
for var in self.variables.values():
Expand All @@ -254,7 +260,7 @@ def make_result(self, solution):
# add it to the flushed set
initial_eq.add(var)

return Assumptions(always_eq = marked, initial_eq = initial_eq)
return Assumptions(always_eq=marked, initial_eq=initial_eq)

def run(self):
debug("Must equal:")
Expand Down

0 comments on commit cba7afe

Please sign in to comment.