In [7]:
from ortools.linear_solver import pywraplp as mip
import numpy.linalg as la
import numpy as np
import networkx as nx

# Synthetic candidate graph: every frame holds two detections (pid 0 and 1),
# and each detection is connected, with a constant score of 0.9, to both
# detections of each of the next T-1 frames. Keys are
# (t_src, pid_src, t_dst, pid_dst) with t_src < t_dst < t_src + T.
n_frames = 150
T = 5
graph_3d = {
    (t_src, pid_src, t_dst, pid_dst): 0.9
    for t_src in range(n_frames - 1)
    for pid_src in range(2)
    for t_dst in range(t_src + 1, min(t_src + T, n_frames))
    for pid_dst in range(2)
}

print("# edges", len(graph_3d))


# -- model: one binary variable per candidate edge --
solver = mip.Solver('t', mip.Solver.CBC_MIXED_INTEGER_PROGRAMMING)

# Scores are clamped into [SCORE_EPS, 1 - SCORE_EPS] before taking log-odds so
# a score of exactly 0 or 1 gives a large finite weight instead of an infinite
# weight or a division by zero.
SCORE_EPS = 1e-6

pids_per_frame = {}  # frame index -> set of detection ids seen in that frame
Tau = {}  # (tA, pidA, tB, pidB) -> BoolVar; 1 when the two detections are linked
costs = {}  # (tA, pidA, tB, pidB) -> log-odds weight of that edge
for edge, score in graph_3d.items():
    tA, pidA, tB, pidB = edge
    Tau[edge] = solver.BoolVar('t[%i,%i,%i,%i]' % edge)
    p = min(max(score, SCORE_EPS), 1 - SCORE_EPS)
    costs[edge] = np.log(p / (1 - p))
    pids_per_frame.setdefault(tA, set()).add(pidA)
    pids_per_frame.setdefault(tB, set()).add(pidB)

# Objective: total log-odds of the selected edges. Scores above 0.5 give a
# positive weight (linking is rewarded); scores below 0.5 are penalised.
Sum = solver.Sum(Tau[edge] * costs[edge] for edge in graph_3d)

# Matching constraints: for every frame pair t1 < t2 inside the window, each
# detection of t1 links to at most one detection of t2, and vice versa.
# Frames without any detections are simply skipped instead of raising KeyError.
for t1 in range(n_frames - 1):
    pids1 = pids_per_frame.get(t1, ())
    for t2 in range(t1 + 1, min(t1 + T, n_frames)):
        pids2 = pids_per_frame.get(t2, ())
        for pid1 in pids1:
            outgoing = [Tau[t1, pid1, t2, pid2] for pid2 in pids2
                        if (t1, pid1, t2, pid2) in Tau]
            if outgoing:  # an empty sum would only add a vacuous row
                solver.Add(solver.Sum(outgoing) <= 1)

        for pid2 in pids2:
            incoming = [Tau[t1, pid1, t2, pid2] for pid1 in pids1
                        if (t1, pid1, t2, pid2) in Tau]
            if incoming:
                solver.Add(solver.Sum(incoming) <= 1)

# -- transitivity --
# For every triple of detections a (frame t1), b (t2), c (t3), t1 < t2 < t3,
# that are pairwise joined by candidate edges, forbid selecting exactly two of
# the three edges: a-b and b-c force a-c, a-b and a-c force b-c, and b-c and
# a-c force a-b. The frame ranges keep t3 - t1 < T, the same window used when
# the candidate edges were generated, so every needed a-c edge can exist.
nbr_trans_constraints = 0
for t1 in range(n_frames - 2):
    for t2 in range(t1 + 1, min(t1 + T - 1, n_frames - 1)):
        for t3 in range(t2 + 1, min(t1 + T, n_frames)):
            # .get(): frames without detections contribute no triples
            for pid1 in pids_per_frame.get(t1, ()):
                for pid2 in pids_per_frame.get(t2, ()):
                    for pid3 in pids_per_frame.get(t3, ()):
                        ab = (t1, pid1, t2, pid2)
                        bc = (t2, pid2, t3, pid3)
                        ac = (t1, pid1, t3, pid3)
                        if ab not in Tau or bc not in Tau or ac not in Tau:
                            continue
                        solver.Add(Tau[ab] + Tau[bc] - 1 <= Tau[ac])
                        solver.Add(Tau[ab] + Tau[ac] - 1 <= Tau[bc])
                        solver.Add(Tau[bc] + Tau[ac] - 1 <= Tau[ab])
                        nbr_trans_constraints += 3


print("# transitivity constraints", nbr_trans_constraints)
            
# -- solve --
solver.Maximize(Sum)
RESULT = solver.Solve()
print("Time = ", solver.WallTime(), " ms")
print("result:", RESULT)
# Variable values and the objective are only meaningful after an OPTIMAL or
# FEASIBLE solve; fail loudly instead of decoding an infeasible/aborted model.
if RESULT not in (mip.Solver.OPTIMAL, mip.Solver.FEASIBLE):
    raise RuntimeError('MIP solve did not produce a solution (status %d)' % RESULT)
print('\nTotal cost:', solver.Objective().Value())

# -- decode: selected edges -> tracks (connected components) --
node_lookup = {}  # (t, pid) -> node number
reverse_node_lookup = {}  # node number -> (t, pid)
G = nx.Graph()
nid = 1
for t in range(n_frames):
    # .get(): frames without detections simply add no nodes
    for pid in pids_per_frame.get(t, ()):
        node_lookup[t, pid] = nid
        reverse_node_lookup[nid] = (t, pid)
        G.add_node(nid, key=(t, pid))
        nid += 1

for (t1, pid1, t2, pid2), v in Tau.items():
    assert t1 < t2
    # Threshold at 0.5 rather than > 0: MIP solvers may report binary values
    # with small numerical residues (e.g. 1e-9), which must not count as links.
    if v.solution_value() > 0.5:
        nid1 = node_lookup[t1, pid1]
        nid2 = node_lookup[t2, pid2]
        c = costs[t1, pid1, t2, pid2]
        G.add_edge(nid1, nid2, cost=c)

# Each connected component is one track; unlinked detections form singletons.
for global_pid, comp in enumerate(nx.connected_components(G)):
    print('\ncomponent ', global_pid)
    # Node numbers were assigned frame by frame, so sorting them lists the
    # track's detections in chronological order.
    for node_id in sorted(comp):
        t, local_pid = G.nodes[node_id]['key']
        print(t, local_pid)

# edges 2360
# transitivity constraints 21120
Time =  38216  ms
result: 0

Total cost: 2592.7250012568147

component  0
0 0
1 1
2 0
3 1
4 1
5 1
6 0
7 0
8 1
9 1
10 0
11 0
12 1
13 1
14 0
15 1
16 0
17 0
18 1
19 0
20 1
21 1
22 0
23 1
24 0
25 1
26 1
27 0
28 1
29 0
30 0
31 1
32 0
33 1
34 1
35 0
36 1
37 0
38 1
39 1
40 0
41 1
42 0
43 0
44 0
45 1
46 0
47 1
48 1
49 1
50 0
51 0
52 0
53 1
54 1
55 0
56 1
57 0
58 0
59 1
60 0
61 1
62 1
63 0
64 0
65 0
66 0
67 1
68 1
69 1
70 0
71 1
72 0
73 1
74 1
75 0
76 0
77 0
78 1
79 1
80 1
81 0
82 0
83 0
84 1
85 1
86 1
87 0
88 0
89 0
90 1
91 1
92 1
93 1
94 0
95 0
96 1
97 0
98 1
99 1
100 0
101 1
102 0
103 1
104 1
105 0
106 1
107 0
108 0
109 1
110 0
111 1
112 0
113 0
114 1
115 0
116 0
117 1
118 1
119 0
120 0
121 1
122 1
123 0
124 0
125 1
126 1
127 1
128 0
129 0
130 1
131 1
132 1
133 0
134 0
135 1
136 0
137 1
138 1
139 0
140 1
141 0
142 1
143 0
144 0
145 1
146 0
147 1
148 1
149 0

component  1
0 1
1 0
2 1
3 0
4 0
5 0
6 1
7 1
8 0
9 0
10 1
11 1
12 0
13 0
14 1
15 0
16 1
1