In [4]:
import sys
sys.path.append('../../../python')
import os
import re
from os.path import join as ojoin

In [5]:
from cop_kmeans import cop_kmeans, euclidean_distance

def read_data(datafile):
    """Load a whitespace-separated numeric table from a text file.

    Every non-blank line becomes one row: the line is split on whitespace
    and each field is converted with ``float``.

    Args:
        datafile: Path of the text file to read.

    Returns:
        A list of rows (each a list of floats), in file order.

    Raises:
        ValueError: If a field cannot be converted to ``float``.
    """
    with open(datafile, 'r') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [
            [float(field) for field in text.split()]
            for text in stripped_lines
            if text != ''
        ]

def read_constraints(consfile):
    """Parse a pairwise-constraint file into must-link and cannot-link lists.

    Each non-blank line is split on whitespace; the first two fields are
    read as integer point indices and the third as an integer flag.
    A flag of ``1`` files the pair under must-link, ``-1`` under
    cannot-link, and any other value is ignored.

    Args:
        consfile: Path of the constraint file.

    Returns:
        A ``(ml, cl)`` tuple of lists of ``(int, int)`` pairs, in file order.
    """
    must_link = []
    cannot_link = []
    # Flag value -> list that collects pairs carrying that flag.
    destination = {1: must_link, -1: cannot_link}
    with open(consfile, 'r') as handle:
        for raw in handle:
            fields = raw.strip().split()
            if not fields:
                continue
            pair = (int(fields[0]), int(fields[1]))
            bucket = destination.get(int(fields[2]))
            if bucket is not None:
                bucket.append(pair)
    return must_link, cannot_link

def cluster_quality(cluster):
    """Score one cluster by its summed squared pairwise distances over its size.

    Distances come from ``euclidean_distance``; each unordered pair of
    members is visited once, and every member is also paired with itself.

    Args:
        cluster: Sequence of points belonging to one cluster.

    Returns:
        ``0.0`` for an empty cluster, otherwise the accumulated squared
        distances divided by ``len(cluster)``.
    """
    size = len(cluster)
    if size == 0:
        return 0.0

    total = 0.0
    for start, point in enumerate(cluster):
        # The slice begins at `start`, so `point` is paired with itself
        # first and then with every later member.
        for other in cluster[start:]:
            total += euclidean_distance(point, other) ** 2
    return total / size
    
def compute_quality(data, cluster_indices):
    clusters = dict()
    for i, c in enumerate(cluster_indices):
        if c in clusters:
            clusters[c].append(data[i])
        else:
            clusters[c] = [data[i]]
    return sum(cluster_quality(c) for c in clusters.values())

def transitive_closure(ml, cl, n):
    """Expand pairwise constraints to their transitive closure over ``n`` points.

    Must-link pairs are treated as undirected edges; all points in the same
    connected component become mutually must-linked. Each cannot-link pair
    is then propagated to every must-link neighbour of both endpoints.

    Components are found with an explicit stack instead of recursion, so a
    long must-link chain cannot exceed the interpreter's recursion limit.

    Args:
        ml: Iterable of ``(i, j)`` must-link index pairs.
        cl: Iterable of ``(i, j)`` cannot-link index pairs.
        n: Number of points; graph keys are ``0 .. n-1``.

    Returns:
        ``(ml_graph, cl_graph)``: two dicts mapping every index in
        ``range(n)`` to the set of indices it must / cannot share a cluster
        with.

    Raises:
        Exception: If two distinct points end up both must-linked and
            cannot-linked.
        KeyError: If a constraint references an index outside ``range(n)``.
    """
    ml_graph = {i: set() for i in range(n)}
    cl_graph = {i: set() for i in range(n)}

    def add_both(d, i, j):
        # Constraints are symmetric, so record the edge in both directions.
        d[i].add(j)
        d[j].add(i)

    for (i, j) in ml:
        add_both(ml_graph, i, j)

    visited = [False] * n
    for start in range(n):
        if visited[start]:
            continue
        # Iterative depth-first search collecting one connected component.
        visited[start] = True
        component = []
        stack = [start]
        while stack:
            node = stack.pop()
            component.append(node)
            for neighbour in ml_graph[node]:
                if not visited[neighbour]:
                    visited[neighbour] = True
                    stack.append(neighbour)
        # Link every member to every *other* member; `members - {x}` avoids
        # introducing a self-link that the input did not already contain.
        members = set(component)
        for x in component:
            ml_graph[x] |= members - {x}

    for (i, j) in cl:
        add_both(cl_graph, i, j)
        for y in ml_graph[j]:
            add_both(cl_graph, i, y)
        for x in ml_graph[i]:
            add_both(cl_graph, x, j)
            for y in ml_graph[j]:
                add_both(cl_graph, x, y)

    for i in ml_graph:
        if not cl_graph[i]:
            # Nothing can conflict with a point that has no cannot-links.
            continue
        for j in ml_graph[i]:
            if j != i and j in cl_graph[i]:
                raise Exception('inconsistent constraints between %d and %d' %(i, j))

    return ml_graph, cl_graph

In [6]:
def get_points(line):
    """Expand a whitespace-separated list of point ids and inclusive ranges.

    A token is either a single integer (``"7"``) or a range written
    ``first..last``, which expands to every integer from ``first`` through
    ``last`` inclusive.

    Args:
        line: Text such as ``"1 3..5 7"``.

    Returns:
        The expanded ids as a list of ints, in token order.
    """
    expanded = []
    for token in line.split():
        if '..' in token:
            low, high = map(int, token.split('..'))
            expanded.extend(range(low, high + 1))
        else:
            expanded.append(int(token))
    return expanded


In [7]:

def get_clusters(fname):
    """Read cluster memberships from a result file as a flat assignment list.

    Only lines that start with ``Cluster`` are used. Such a line is split on
    ``':'``: the text after the ``Cluster`` prefix is parsed as the integer
    cluster id, and the part after the colon is expanded with ``get_points``.

    Args:
        fname: Path of the result file.

    Returns:
        A list whose entry ``p`` is the cluster of point ``p``, shifted to
        start at zero (file id minus one), or ``-1`` if ``p`` appears in no
        cluster. Its length is the largest point id plus one; it is empty
        when the file lists no points.
    """
    clusters = dict()
    num_points = -1
    with open(fname, 'r') as f:
        for line in f:
            line = line.strip()
            if line.startswith('Cluster'):
                line = line.split(':')
                cluster_id = int(line[0][len('Cluster'):])
                points = line[1].strip()
                points = get_points(points)
                # A cluster line may list no points; `default` keeps max()
                # from raising ValueError on the empty list.
                num_points = max(num_points, max(points, default=-1))
                clusters[cluster_id] = points
    assignments = [-1] * (num_points+1)
    for cluster_id in clusters:
        for point_id in clusters[cluster_id]:
            assignments[point_id] = cluster_id - 1
    return assignments

In [14]:
def check_solved(fname):
    """Look for status markers in the first line of a result file.

    Args:
        fname: Path of the result file.

    Returns:
        A ``(solved, opt)`` tuple of booleans: whether the first line
        contains ``'problem is solved'`` and whether it contains
        ``'[optimal solution found]'``.
    """
    with open(fname) as handle:
        header = handle.readline()
    return 'problem is solved' in header, '[optimal solution found]' in header

In [16]:
ds = 'soybean'
rootdir = '.'
datafile = '../../../data/%s/%s.data'%(ds, ds)
data = read_data(datafile)
with open('../../../data/%s/solutions/objective_values.txt'%ds, 'w') as f:
    print('#c\tk\tobj. value', file=f)
    print('-'*26, file=f)

    # Collect the result files first so rows can be written in a
    # reproducible order; os.listdir() returns entries in arbitrary order.
    runs = []
    for fname in os.listdir(rootdir):
        if not fname.startswith(ds): continue
        x = re.match(r'%s\.(\d+)\.(\d+)'%ds, fname)
        if x is None:
            # Starts with the dataset name but is not '<ds>.<ncons>.<k>'
            # (e.g. a notebook or data file) -- not a result file.
            continue
        runs.append((int(x.group(1)), int(x.group(2)), ojoin(rootdir, fname)))

    for ncons, k, fname in sorted(runs):
        solved, opt = check_solved(fname)
        if opt:
            # Optimal run: save the assignment and the recomputed objective.
            cluster_indices = get_clusters(fname)
            with open('../../../data/%s/solutions/solution.%d.%d'%(ds, ncons, k), 'w') as g:
                print('\n'.join(str(j) for j in cluster_indices), file=g)
            print('%d\t%d\t%.4f'%(ncons, k, compute_quality(data, cluster_indices)), file=f)
        elif solved:
            # Solved without an optimal solution -- presumably infeasible;
            # TODO confirm against the solver's output format.
            print('%d\t%d\tINF'%(ncons, k), file=f)
        else:
            # Otherwise the first line is read as a number; presumably the
            # best objective found so far (written with a '~' prefix).
            with open(fname) as h:
                quality = float(h.readline())
            print('%d\t%d\t~%.4f'%(ncons, k, quality), file=f)

In [18]:
def verify_constraints(datafile, clusterfile, consfile):
    """Check that a clustering satisfies every constraint after closure.

    The constraints in ``consfile`` are expanded with ``transitive_closure``
    over the number of rows in ``datafile``, then each point's cluster (from
    ``clusterfile``) is compared with those of its must-link and cannot-link
    partners.

    Args:
        datafile: Path of the data file; only its row count is used.
        clusterfile: Path of the result file parsed by ``get_clusters``.
        consfile: Path of the constraint file.

    Returns:
        ``True`` if no constraint is violated, ``False`` at the first
        violation found.
    """
    data = read_data(datafile)
    ml, cl = read_constraints(consfile)
    ml, cl = transitive_closure(ml, cl, len(data))
    # The assignment does not depend on the point being checked, so parse
    # the cluster file once rather than once per point.
    cluster_indices = get_clusters(clusterfile)
    for point_id in range(len(data)):
        own_cluster = cluster_indices[point_id]
        if any(cluster_indices[id2] != own_cluster for id2 in ml[point_id]):
            return False
        if any(cluster_indices[id2] == own_cluster for id2 in cl[point_id]):
            return False
    return True

In [19]:
ds = 'soybean'
datafile = '../../../data/%s/%s.data'%(ds, ds)
rootdir = '.'
for fname in os.listdir(rootdir):
    if not fname.startswith(ds): continue
    # Filter on the '<ds>.<ncons>.<k>' name before touching the file, so
    # unrelated entries that merely start with the dataset name are skipped
    # instead of crashing on a failed match.
    x = re.match(r'%s\.(\d+)\.\d+'%ds, fname)
    if x is None: continue
    ncons = int(x.group(1))
    # Always open the file through its path under rootdir; the bare name is
    # only valid when rootdir is the current directory.
    clusterfile = ojoin(rootdir, fname)
    solved, opt = check_solved(clusterfile)
    if not opt: continue
    consfile = '../../../data/%s/constraints/%s.%d.cons' %(ds, ds, ncons)
    res = verify_constraints(datafile, clusterfile, consfile)
    print(fname, res)

soybean.60.4 True
soybean.80.4 True
soybean.40.4 True
soybean.60.5 True
soybean.40.5 True
soybean.10.4 True
soybean.40.3 True
soybean.80.5 True
soybean.20.4 True
soybean.10.5 True
soybean.20.3 True
soybean.60.3 True
soybean.20.5 True
