-
Notifications
You must be signed in to change notification settings - Fork 0
/
solver_b.py
109 lines (97 loc) · 3.33 KB
/
solver_b.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
import matplotlib.pyplot as plt
from time import time
from copy import copy
from time import time
from database import Database
from tqdm import tqdm
import math
def score_func(s1, s2):
    """Return the transition score between two collections.

    The score is the smallest of three counts:
      * the elements of ``s1`` that are not shared,
      * the elements of ``s2`` that are not shared,
      * the shared elements.

    The counts are derived from ``len`` and the size of
    ``s1.intersection(s2)``, so they are exact for set arguments.
    """
    shared = len(s1.intersection(s2))
    only_first = len(s1) - shared
    only_second = len(s2) - shared
    return min(only_first, only_second, shared)
def union(s1, s2):
    """Return how many distinct elements ``s1`` and ``s2`` hold together."""
    merged = s1.union(s2)
    return len(merged)
def inter(s1, s2):
    """Return how many elements of ``s1`` also appear in ``s2``."""
    common = s1.intersection(s2)
    return len(common)
class experiment():
    """Greedy slide-ordering experiment backed by a ``Database``.

    ``run`` starts from ``db.get_last_id()``. At each step it appends the
    candidate id with the highest heuristic score (see
    ``_connection_score``). When the current id has no candidates, it
    starts a new chain through ``start_new_chain``. The hyper-parameters
    live in ``dict_hyper`` and also form the printable run name.
    """

    # Class-level setting; not referenced by any method of this class.
    iter_tracker = 100

    def __init__(self, reach2_pen, reach3_pen):
        """Create the experiment.

        Args:
            reach2_pen: weight on the number of distinct ids reachable in
                two hops from a candidate.
            reach3_pen: weight on the number of distinct ids reachable in
                three hops from a candidate.
        """
        # NOTE(review): the meaning of the ``1`` argument is defined by
        # Database, which is not visible here (presumably a dataset selector).
        self.db = Database(1)
        self.dict_hyper = {'reach2_pen': reach2_pen, 'reach3_pen': reach3_pen}

    def run(self, test=False, path_sol='solution_b.txt'):
        """Build the slide order greedily, then score it and write it out.

        Args:
            test: when true, perform at most 101 greedy picks (a quick
                smoke run).
            path_sol: output path passed to ``db.gen_output``. Defaults to
                the previously hard-coded ``'solution_b.txt'``.

        Returns:
            Whatever ``db.score_slides()`` returns for the chosen order.
        """
        reach2_pen = self.dict_hyper['reach2_pen']
        reach3_pen = self.dict_hyper['reach3_pen']
        print(self._name_gen())
        self.db.gen_edges()
        self.db.init_slides(start_ind=0)
        # NOTE(review): presumably dict_id_cons maps an id to the ids it can
        # connect to, and dict_id_used holds the candidate successors of an
        # id. Confirm against Database.
        self.dic = self.db.dict_id_cons
        self.diu = self.db.dict_id_used
        last_id = self.db.get_last_id()
        self.list_ids = [last_id]
        self.db.pop_cons_by_id(last_id)
        self.num_no_con = 0
        # list_ids already holds one id. This assumes slides_length is the
        # total number of slides.
        self.iterations = self.db.slides_length - 1
        if test:
            # Cap the smoke run at the real number of remaining picks, so test
            # mode never asks for more ids than a full run would place.
            self.iterations = min(101, self.iterations)
        for _ in tqdm(range(self.iterations)):
            last_id = self.list_ids[-1]
            cons = list(self.diu[last_id])
            if cons:
                scores = [
                    self._connection_score(con, reach2_pen, reach3_pen)
                    for con in cons
                ]
                # np.argmax picks the first index on ties.
                next_id = cons[np.argmax(scores)]
            else:
                next_id = self.start_new_chain()
            self.db.pop_cons_by_id(next_id)
            self.list_ids.append(next_id)
        self.db.set_slides(self.list_ids)
        score = self.db.score_slides()
        print(self._name_gen())
        print('score: ', score)
        self.db.gen_output(path=path_sol)
        return score

    def _connection_score(self, con, reach2_pen, reach3_pen, include_reach3=True):
        """Score candidate ``con`` using the graph in ``self.dic``.

        The score is ``-len(self.dic[con])`` plus ``reach2_pen`` times the
        number of distinct ids two hops away. When ``include_reach3`` is
        true, it also adds ``reach3_pen`` times the number of distinct ids
        three hops away. When ``include_reach3`` is false, ``reach3_pen``
        is ignored.
        """
        con_cons = self.dic[con]
        score = -len(con_cons)
        reach2 = set()
        reach3 = set()
        for con_con in con_cons:
            neighbours = self.dic[con_con]
            reach2.update(neighbours)
            if include_reach3:
                for con_con_con in neighbours:
                    reach3.update(self.dic[con_con_con])
        # Keep the original summation order so float results (and argmax ties)
        # are unchanged.
        if include_reach3:
            score += reach2_pen * len(reach2) + reach3_pen * len(reach3)
        else:
            score += reach2_pen * len(reach2)
        return score

    def start_new_chain(self):
        """Choose an id to start a new chain when the current id has no candidates.

        Scores every key of ``self.dic`` with the two-hop part of the
        heuristic only (there is no three-hop term here). Returns the first
        best-scoring id and increments ``self.num_no_con``.

        Raises:
            ValueError: if ``self.dic`` is empty.
        """
        reach2_pen = self.dict_hyper['reach2_pen']
        remaining = list(self.dic.keys())
        if not remaining:
            raise ValueError('start_new_chain: no ids left to start a chain from')
        self.num_no_con += 1
        # NOTE(review): every key of self.dic is a candidate. Whether ids
        # already in list_ids are removed from dic depends on
        # Database.pop_cons_by_id. Confirm.
        scores = [
            self._connection_score(con, reach2_pen, None, include_reach3=False)
            for con in remaining
        ]
        return remaining[np.argmax(scores)]

    def _name_gen(self):
        """Return a run name such as ``'reach2_pen:_0.001,_reach3_pen:_-1e-06'``."""
        return ',_'.join(
            str(key) + ':_' + str(val) for key, val in self.dict_hyper.items()
        )
# Script entry: runs one experiment with fixed penalty weights at module level,
# i.e. on both direct execution and import.
# NOTE(review): consider an `if __name__ == "__main__":` guard if this module
# is ever imported elsewhere.
e = experiment(reach2_pen=0.001, reach3_pen=-0.000001)
e.run()