In [1]:
import numpy as np
from ortools.linear_solver import pywraplp
import json
from collections import Counter
from tqdm.auto import tqdm

In [2]:
# Federated task partition to analyse. The cells below read its 'client_names'
# list, the shared 'dtest' split, and a per-client 'dtrain' split, each split
# carrying its labels under 'y'.
TASK_DATA_PATH = '../fedtask/mnist_cnum6_dist2_skew0.8_seed0/data.json'
with open(TASK_DATA_PATH, 'r', encoding='utf-8') as f1:
    fdata1 = json.load(f1)

In [3]:
# Number of label classes (labels 0..N_CLASSES-1) used when turning label
# lists into per-class count vectors; the task file loaded above is an MNIST
# split, whose labels are the ten digits.
N_CLASSES = 10

In [4]:
def counter_to_array(labels_, n_classes=None):
    """Return the per-class label counts as a float vector.

    Args:
        labels_: iterable of hashable labels (e.g. the ints stored under 'y').
        n_classes: number of classes to count; defaults to the notebook-level
            ``N_CLASSES`` when omitted.

    Returns:
        np.ndarray of dtype float64 and shape (n_classes,) whose entry ``c`` is
        the number of labels equal to ``c``.

    Note:
        Labels that are not in ``range(n_classes)`` are silently not counted,
        so the vector total can be smaller than ``len(labels_)``.
    """
    if n_classes is None:
        n_classes = N_CLASSES
    counts = Counter(labels_)
    # Counter returns 0 for missing keys, so absent classes get a zero entry.
    return np.array([counts[c] for c in range(n_classes)], dtype=float)

In [5]:
# Empirical label distribution of the shared test split: class counts / total.
N_test = counter_to_array(fdata1['dtest']['y'])
P_test = np.divide(N_test, np.sum(N_test))

# The same normalisation applied to every client's local training split,
# stored by client name in the order given by fdata1['client_names'].
P_clients = {}
for client in fdata1['client_names']:
    N_client = counter_to_array(fdata1[client]['dtrain']['y'])
    P_clients[client] = np.divide(N_client, np.sum(N_client))

In [6]:
def wd(P, Q, cost=None):
    """Optimal-transport (Wasserstein) distance between two discrete distributions.

    Solves the transport LP with OR-Tools' CLP backend::

        min_pi  sum_{i,j} cost[i, j] * pi[i, j]
        s.t.    sum_j pi[i, j] == P[i],   sum_i pi[i, j] == Q[j],   pi >= 0

    Args:
        P: 1-D sequence of source masses (length n).
        Q: 1-D sequence of target masses (length m). The equality constraints
            force P and Q to carry the same total mass (up to solver
            tolerance); otherwise the LP is infeasible.
        cost: optional (n, m) ground-cost matrix. Defaults to the 0-1 metric
            (0 where i == j, 1 elsewhere), for which the optimum equals the
            total-variation distance 0.5 * sum(|P - Q|) when n == m.

    Returns:
        The optimal objective value as a float, or None (after printing a
        message) when the solver does not report an optimal solution.

    Raises:
        ValueError: if ``cost`` does not have shape (len(P), len(Q)).
        RuntimeError: if the CLP backend cannot be created.
    """
    n, m = len(P), len(Q)
    if cost is None:
        cost = 1.0 - np.eye(n, m)
    cost = np.asarray(cost, dtype=float)
    if cost.shape != (n, m):
        raise ValueError('cost must have shape {}, got {}'.format((n, m), cost.shape))

    solver = pywraplp.Solver.CreateSolver('CLP')
    if solver is None:
        # CreateSolver returns None (rather than raising) when the backend is missing.
        raise RuntimeError('OR-Tools CLP backend is not available')

    pi = {(i, j): solver.NumVar(0, solver.infinity(), 'pi[{}, {}]'.format(i, j))
          for i in range(n) for j in range(m)}
    # Row marginals must reproduce P, column marginals must reproduce Q.
    for i in range(n):
        solver.Add(sum(pi[i, j] for j in range(m)) == P[i])
    for j in range(m):
        solver.Add(sum(pi[i, j] for i in range(n)) == Q[j])
    # Zero-cost cells (e.g. the diagonal of the default 0-1 metric) are left out
    # of the objective. The variable is kept on the left of the product so the
    # pywraplp expression operator handles the numeric coefficient.
    solver.Minimize(sum(pi[i, j] * float(cost[i, j])
                        for i in range(n) for j in range(m) if cost[i, j] != 0))

    status = solver.Solve()
    if status == pywraplp.Solver.OPTIMAL:
        return solver.Objective().Value()
    print("No optimal solution!")
    return None

In [7]:
# Distance between each client's training label distribution and the test
# label distribution, in fdata1['client_names'] order.
tmp = []
for client in tqdm(fdata1['client_names']):
    tmp.append(wd(P_clients[client], P_test))
    # tqdm.write prints above the progress bar instead of breaking it, as a bare print() does.
    tqdm.write('{} {}'.format(client, tmp[-1]))
# wd() returns None when the LP is not solved to optimality. Fail here with a
# clear message rather than letting None produce a dtype=object array that
# breaks the arithmetic further down.
assert all(d is not None for d in tmp), 'wd() found no optimal solution for at least one client'
tmp = np.array(tmp)
tmp

  0%|          | 0/6 [00:00<?, ?it/s]

Client0 0.40500031919744645
Client1 0.39349032954381924
Client2 0.3764173535057017
Client3 0.25248232341779897
Client4 0.3892663233651818
Client5 0.32802284800445003


array([0.40500032, 0.39349033, 0.37641735, 0.25248232, 0.38926632,
       0.32802285])

In [8]:
# Multivariate random variables: the same transport LP as wd(), applied to two
# uniform empirical distributions of random points (N_A vs N_B samples in
# R^DIM) with Euclidean ground cost.
N_A, N_B, DIM = 32, 64, 128
np.random.seed(0)
a = np.random.rand(N_A, DIM)
b = np.random.rand(N_B, DIM)
# Pairwise Euclidean ground costs, computed once up front.
cost_ab = np.array([[np.linalg.norm(a[i] - b[j], 2) for j in range(N_B)]
                    for i in range(N_A)])

solver = pywraplp.Solver.CreateSolver('CLP')
if solver is None:
    raise RuntimeError('OR-Tools CLP backend is not available')
pi = {(i, j): solver.NumVar(0, solver.infinity(), 'pi[{}, {}]'.format(i, j))
      for i in range(N_A) for j in range(N_B)}
# Each sample carries uniform mass: 1/N_A on the source side, 1/N_B on the target side.
for i in range(N_A):
    solver.Add(sum(pi[i, j] for j in range(N_B)) == 1.0 / N_A)
for j in range(N_B):
    solver.Add(sum(pi[i, j] for i in range(N_A)) == 1.0 / N_B)
solver.Minimize(sum(pi[i, j] * cost_ab[i, j] for i in range(N_A) for j in range(N_B)))
status = solver.Solve()
if status == pywraplp.Solver.OPTIMAL:
    print(solver.Objective().Value())
else:
    print("No optimal solution!")

4.224976908444024


In [11]:
# Turn per-client distances into normalised weights. Each distance is reflected
# about the mean distance, so clients closer than average to the test
# distribution get larger weights. The result is then rescaled to sum to 1.
# NOTE(review): a client whose distance exceeds twice the mean would receive a
# negative weight — confirm that is acceptable wherever these weights are used.
x = np.subtract(2 * np.mean(tmp), tmp)
x /= np.sum(x)
x

array([0.14449378, 0.14986054, 0.15782116, 0.21560837, 0.15183007,
       0.18038608])