In [522]:
import random
import numpy as np
from scipy.optimize import linprog

# Generate a random assignment problem

In [523]:
n_elements = 10        # elements present in the old clustering
n_added_elements = 5   # extra elements that only appear in the new clustering
SEED = 42              # fixed seed so Restart & Run All rebuilds the same problem
random.seed(SEED)

# Generate old clusters: shuffle the elements, then deal them round-robin
# into N groups (a random partition into near-equal parts).
# list(...) is required: random.shuffle works in place, and on Python 3
# range() returns an immutable object.
N = 3
elements = list(range(n_elements))
random.shuffle(elements)
old = [set(elements[i::N]) for i in range(N)]

# Generate new clusters the same way over the enlarged element set.
M = 4
elements = list(range(n_elements + n_added_elements))
random.shuffle(elements)
new = [set(elements[i::M]) for i in range(M)]

# Add M virtual empty old clusters so that every new cluster can always be
# matched to some old cluster (a new cluster may be matched to an empty one).
N = N + M
for j in range(M):
    old.append(set([]))

print(old)
print(new)

[set([8, 1, 5, 0]), set([2, 4, 6]), set([9, 3, 7]), set([]), set([]), set([]), set([])]
[set([9, 3, 14, 1]), set([10, 4, 13, 6]), set([8, 11, 12, 5]), set([0, 2, 7])]


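# Fixed (hand-made) example problem

This cell overrides the random instance above with a small deterministic problem whose optimal matching can be checked by hand.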

In [547]:
# Small hand-made instance: element 1 leaves its singleton cluster and joins
# the large one, while element 2 leaves the large cluster to stand alone.
N, M = 2, 2
old = [set([1]), set([2, 3, 4, 5])]
new = [set([2]), set([1, 3, 4, 5])]

# Pad with M virtual empty old clusters so every new cluster can be matched.
N += M
old = old + [set([]) for _ in range(M)]

print(old)
print(new)

[set([1]), set([2, 3, 4, 5]), set([]), set([])]
[set([2]), set([1, 3, 4, 5])]


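# Solve using linear programming

Each (old, new) pair gets a cost, and we look for the matching of minimum total cost. It is written as an LP over edge variables $x_{ij} \ge 0$ with the following constraints:

- $\sum_j x_{ij} \le 1$: each old cluster is used at most once.
- $\sum_i x_{ij} = 1$: each new cluster is matched exactly once.

This constraint matrix is the incidence matrix of a bipartite graph, which is totally unimodular. Optimal vertex solutions are therefore integral, so the LP relaxation solves the assignment problem exactly.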

In [548]:
# cost[i, j] is the (minimised) cost of matching old cluster i to new cluster j:
#   - primary term: minus the number of elements the two clusters share;
#   - secondary term: minus a bonus 1 / (1 + |symmetric difference|) in (0, 1],
#     which separates pairs with equal overlap (e.g. an empty old cluster
#     scores better against a small new cluster than against a large one).
cost = np.zeros((N, M))

for row, old_set in enumerate(old):
    for col, new_set in enumerate(new):
        shared = len(old_set & new_set)
        differing = len(old_set ^ new_set)
        cost[row, col] = -shared - 1. / (1 + differing)

print(cost)

[[-0.33333333 -1.25      ]
 [-1.25       -3.33333333]
 [-0.5        -0.2       ]
 [-0.5        -0.2       ]]


In [551]:
# Assignment LP over the bipartite graph old x new. Variable x[i*M + j] is the
# (fractional) amount of old cluster i matched to new cluster j; this row-major
# layout matches cost.ravel(). Dimensions are taken from cost itself so the
# constraint matrices always agree with the length of c.
n_old, n_new = cost.shape
n_edges = n_old * n_new
c = cost.ravel()

# Each old cluster may be assigned to at most one new cluster:
# row i sums the contiguous variables x[i*M], ..., x[(i+1)*M - 1].
A_ub = np.zeros((n_old, n_edges))
for i in range(n_old):
    A_ub[i, i * n_new:(i + 1) * n_new] = 1.0
b_ub = np.ones(n_old)

# Each new cluster has to be assigned to exactly one old cluster:
# row j sums the strided variables x[j], x[j + M], x[j + 2M], ...
A_eq = np.zeros((n_new, n_edges))
for j in range(n_new):
    A_eq[j, j::n_new] = 1.0
b_eq = np.ones(n_new)

# Default bounds are x >= 0; the constraints above cap every variable at 1.
r = linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, options={"maxiter": 5000})
print(r)

# Fail loudly instead of letting an infeasible/aborted solve reach the report.
if not r.success:
    raise RuntimeError("assignment LP failed: %s" % r.message)

  status: 0
   slack: array([ 1.,  0.,  0.,  1.])
 success: True
     fun: -3.8333333333333339
       x: array([ 0.,  0.,  0.,  1.,  1.,  0.,  0.,  0.])
 message: 'Optimization terminated successfully.'
     nit: 6


In [552]:
# Report, for every old cluster (including the virtual empty ones), its cost
# and the new cluster it was matched to, or "/" when it is left unmatched.
#
# The solver returns floats, so testing x == 1 can miss a match returned as
# e.g. 0.9999999999. Since each row satisfies sum_j x[i, j] <= 1 with x >= 0,
# at most one entry per row can exceed 0.5, making it a safe threshold.
assignment = r.x.reshape(N, M)
for i in range(N):
    for j in range(M):
        if assignment[i, j] > 0.5:
            print("%15s: %30s -> %30s" % (cost[i, j], sorted(old[i]), sorted(new[j])))
            break
    else:
        # No variable in this row is (close to) 1: old cluster i is unmatched.
        print("%15s: %30s -> %30s" % (0, sorted(old[i]), "/"))

print("TOTAL = %15s" % r.fun)

              0:                            [1] ->                              /
 -3.33333333333:                   [2, 3, 4, 5] ->                   [1, 3, 4, 5]
           -0.5:                             [] ->                            [2]
              0:                             [] ->                              /
TOTAL =  -3.83333333333
