In [2]:
%load_ext autoreload
%autoreload 2

## (Strategy Proofness - First Order Stochastic Dominance)
$$
\forall i\in W\cup F \ \forall \succ_i \forall \succ_{-i} \forall \succ'_i \forall j \\
\sum_{j'\succeq j}(g_{ij'}(\succ'_i,\succ_{-i})-g_{ij'}(\succ_i,\succ_{-i})) \leq 0
$$

## (Ex-ante Stability)
$\nexists (w,f)\in W\times F$ s.t. $\exist f'\ [g_{wf'}(\succ)>0\land f\succ_w f']\ \exist w'\ [g_{w'f}(\succ)>0\land w\succ_f w']$

## (Stability of Deterministic Matching)
$$
\forall (w,f)\in W\times F \ g_{wf}+\sum_{f'\succ_w f}g_{wf'}+\sum_{w'\succ_f w}g_{w'f}\geq 1
$$

## (Ex-post Stability)
A randomized matching is **ex-post stable** iff it can be decomposed into deterministic stable matchings.

## (Fractionally Stable)
$$
\forall (w,f)\in W\times F \ g_{wf}+\sum_{f'\succ_w f}g_{wf'}+\sum_{w'\succ_f w}g_{w'f}\geq 1
$$

### (Violation of Fractionally Stability)
$$
\sum_\succ\sum_w\sum_f\max\left\{0,1-g_{wf}(\succ)-\sum_{w'\succ_f w}g_{w'f}(\succ)-\sum_{f'\succ_w f}g_{wf'}(\succ)\right\}
$$

## (Primal)
$$
\begin{align*}
    \min & \sum_\succ\sum_w\sum_f t_{wf}(\succ)\\
    \text{s.t.} & \sum_f g_{wf}(\succ)\leq 1 & \forall\succ\forall w \\
    & \sum_w g_{wf}(\succ)\leq 1 & \forall \succ\forall f\\
    & t_{wf}(\succ)\geq 1-g_{wf}(\succ)-\sum_{w'\succ_f w}g_{w'f}(\succ)-\sum_{f'\succ_w f}g_{wf'}(\succ) & \forall\succ\forall w\forall f\\
    & \sum_{f'\succ_wf}(g_{wf'}(\succ_w',\succ_{-w})-g_{wf'}(\succ))\leq 0 & \forall\succ\forall w\forall\succ_{w}'\forall f\\
    & \sum_{w'\succ_fw}(g_{w'f}(\succ_f',\succ_{-f})-g_{w'f}(\succ))\leq 0 & \forall\succ\forall f\forall\succ_{f}'\forall w\\
    & g_{wf}(\succ)\geq 0,\ t_{wf}(\succ)\geq 0 & \forall\succ\forall w \forall y
\end{align*}
$$



## (Dual)
$$
\begin{align*}
    \min & \sum_\succ\left(\sum_wx_w(\succ)+\sum_fy_f(\succ)-\sum_w\sum_fz_{wf}(\succ)\right)\\
    \text{s.t.}  \\
    & \forall \succ \forall w \forall f\\
    & x_w(\succ)+y_f(\succ)-z_{wf}(\succ)-\sum_{f'\prec_wf}z_{wf'}(\succ)-\sum_{w'\prec_fw}z_{w'f}(\succ)-\sum_{\succ_w'}\left(\sum_{f'\prec_w f}u_{wf'}(\succ_w',\succ_w,\succ_{-w})-\sum_{f'\prec_w'f}u_{wf'}(\succ_w,\succ_w',\succ_{-w})\right)-\sum_{\succ_f'}\left(\sum_{w'\prec_fw}v_{w'f}(\succ_f',\succ_f,\succ_{-f})-\sum_{w'\prec_f'w}v_{w'f}(\succ_f,\succ_f',\succ_{-f})\right)\geq 0 & \forall\succ\forall w\forall f\\
    & x_w(\succ)\geq 0,\ y_f(\succ)\geq 0,\ 0\leq z_{wf}(\succ)\leq 1 & \forall\succ\forall w\forall f\\
    & u_{wf}(\succ'_w,\succ_w,\succ_{-w})\geq 0 & \forall\succ\forall w\forall\succ_w'\forall f\\
    & v_{wf}(\succ'_f,\succ_f,\succ_{-f})\geq 0 & \forall\succ\forall f\forall\succ_f'\forall w
\end{align*}
$$

In [6]:
import os
import sys
import time
import logging
import argparse
import numpy as np
from random import random
import itertools
from pathlib import Path

sys.path.append(str(Path("primal_dual_matching.ipynb").resolve().parent.parent))

import torch
import torch.nn
from torch import optim
import torch.nn.functional as F

from data import Data

from primal_net import PrimalNet
from primal_loss import *
from primal_train import *

from dual_net import DualNet
from dual_loss import *
from dual_train import *

import torch
import matplotlib.pyplot as plt
import seaborn as sns

### 2*2

In [7]:
device = "mps" #if torch.cuda.is_available() else "cpu"
lambd = np.ones((2,2))*0.001
# lambd = cfg.lambd

cfg = HParams(num_agents = 2,
              device = device,
              lambd = lambd,
              rho = 0.1,
              lagr_iter = 1000,
              batch_size = 512)

cfg.lr = 1e-4

np.random.seed(cfg.seed)

G = Data(cfg)

model = PrimalNet(cfg)
model.to(device)

train_primal(cfg,G,model)

2024-12-05 17:51:34,766:INFO:[TRAIN-ITER]: 0, [Time-Elapsed]: 0.745740, [Total-Loss]: 1288113606067486720.000000
2024-12-05 17:51:34,767:INFO:[CONSTR-Vio]: 0.000153, [OBJECTIVE]: 1288113606067486720.000000
2024-12-05 17:51:37,481:INFO:[TRAIN-ITER]: 100, [Time-Elapsed]: 3.461715, [Total-Loss]: 0.266347
2024-12-05 17:51:37,483:INFO:[CONSTR-Vio]: 0.000073, [OBJECTIVE]: 0.266347
2024-12-05 17:51:40,168:INFO:[TRAIN-ITER]: 200, [Time-Elapsed]: 6.147922, [Total-Loss]: 0.257339
2024-12-05 17:51:40,168:INFO:[CONSTR-Vio]: 0.000070, [OBJECTIVE]: 0.257339
2024-12-05 17:51:42,902:INFO:[TRAIN-ITER]: 300, [Time-Elapsed]: 8.879083, [Total-Loss]: 0.246333
2024-12-05 17:51:42,903:INFO:[CONSTR-Vio]: 0.000071, [OBJECTIVE]: 0.246333
2024-12-05 17:51:45,778:INFO:[TRAIN-ITER]: 400, [Time-Elapsed]: 11.758585, [Total-Loss]: 0.251336
2024-12-05 17:51:45,780:INFO:[CONSTR-Vio]: 0.000074, [OBJECTIVE]: 0.251336
2024-12-05 17:51:48,629:INFO:[TRAIN-ITER]: 500, [Time-Elapsed]: 14.609509, [Total-Loss]: 0.228320
2024-12

[[0.00100087 0.00100226]
 [0.00100101 0.00100244]]


2024-12-05 17:52:05,296:INFO:	[VAL-ITER]: 1000, [LOSS]: 0.256776, [Constr-vio]: 0.000071, [Objective]: 0.256776
2024-12-05 17:52:07,986:INFO:[TRAIN-ITER]: 1100, [Time-Elapsed]: 33.962279, [Total-Loss]: 0.250337
2024-12-05 17:52:07,988:INFO:[CONSTR-Vio]: 0.000073, [OBJECTIVE]: 0.250337
2024-12-05 17:52:10,636:INFO:[TRAIN-ITER]: 1200, [Time-Elapsed]: 36.611169, [Total-Loss]: 0.256344
2024-12-05 17:52:10,637:INFO:[CONSTR-Vio]: 0.000069, [OBJECTIVE]: 0.256344
2024-12-05 17:52:13,359:INFO:[TRAIN-ITER]: 1300, [Time-Elapsed]: 39.339820, [Total-Loss]: 0.262347
2024-12-05 17:52:13,360:INFO:[CONSTR-Vio]: 0.000068, [OBJECTIVE]: 0.262346
2024-12-05 17:52:16,014:INFO:[TRAIN-ITER]: 1400, [Time-Elapsed]: 41.991992, [Total-Loss]: 0.277360
2024-12-05 17:52:16,018:INFO:[CONSTR-Vio]: 0.000070, [OBJECTIVE]: 0.277360
2024-12-05 17:52:18,691:INFO:[TRAIN-ITER]: 1500, [Time-Elapsed]: 44.671374, [Total-Loss]: 0.255348
2024-12-05 17:52:18,692:INFO:[CONSTR-Vio]: 0.000070, [OBJECTIVE]: 0.255348
2024-12-05 17:52:2

[[0.00100152 0.00100505]
 [0.00100211 0.00100475]]


2024-12-05 17:52:35,704:INFO:	[VAL-ITER]: 2000, [LOSS]: 0.255555, [Constr-vio]: 0.000071, [Objective]: 0.255555
2024-12-05 17:52:38,475:INFO:[TRAIN-ITER]: 2100, [Time-Elapsed]: 64.455088, [Total-Loss]: 0.241336
2024-12-05 17:52:38,477:INFO:[CONSTR-Vio]: 0.000066, [OBJECTIVE]: 0.241336
2024-12-05 17:52:41,138:INFO:[TRAIN-ITER]: 2200, [Time-Elapsed]: 67.114496, [Total-Loss]: 0.264342
2024-12-05 17:52:41,139:INFO:[CONSTR-Vio]: 0.000067, [OBJECTIVE]: 0.264342
2024-12-05 17:52:43,941:INFO:[TRAIN-ITER]: 2300, [Time-Elapsed]: 69.921158, [Total-Loss]: 0.255348
2024-12-05 17:52:43,943:INFO:[CONSTR-Vio]: 0.000068, [OBJECTIVE]: 0.255348
2024-12-05 17:52:46,586:INFO:[TRAIN-ITER]: 2400, [Time-Elapsed]: 72.566380, [Total-Loss]: 0.263343
2024-12-05 17:52:46,587:INFO:[CONSTR-Vio]: 0.000071, [OBJECTIVE]: 0.263343
2024-12-05 17:52:49,296:INFO:[TRAIN-ITER]: 2500, [Time-Elapsed]: 75.276191, [Total-Loss]: 0.242336
2024-12-05 17:52:49,297:INFO:[CONSTR-Vio]: 0.000075, [OBJECTIVE]: 0.242336
2024-12-05 17:52:5

[[0.00100265 0.00100773]
 [0.00100314 0.00100719]]


2024-12-05 17:53:05,981:INFO:	[VAL-ITER]: 3000, [LOSS]: 0.256476, [Constr-vio]: 0.000071, [Objective]: 0.256476


KeyboardInterrupt: 

### 3* 3

In [8]:
device = "mps" #if torch.cuda.is_available() else "cpu"
lambd = np.ones((3,3))*0.001
# lambd = cfg.lambd

cfg = HParams(num_agents = 3,
              device = device,
              lambd = lambd,
              rho = 0.1,
              lagr_iter = 1000,
              batch_size = 512)

cfg.lr = 1e-4

np.random.seed(cfg.seed)

G = Data(cfg)

In [9]:
model = PrimalNet(cfg)
model.to(device)

PrimalNet(
  (input_block): Sequential(
    (0): Linear(in_features=18, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): LeakyReLU(negative_slope=0.01)
    (8): Linear(in_features=256, out_features=256, bias=True)
    (9): LeakyReLU(negative_slope=0.01)
  )
  (layer_out): Linear(in_features=256, out_features=9, bias=True)
)

In [10]:
train_primal(cfg,G,model)

2024-12-05 17:53:13,873:INFO:[TRAIN-ITER]: 0, [Time-Elapsed]: 1.593212, [Total-Loss]: 4084071070187913216.000000
2024-12-05 17:53:13,874:INFO:[CONSTR-Vio]: 0.005378, [OBJECTIVE]: 4084071070187913216.000000
2024-12-05 17:53:18,109:INFO:[TRAIN-ITER]: 100, [Time-Elapsed]: 5.823276, [Total-Loss]: 0.895723
2024-12-05 17:53:18,110:INFO:[CONSTR-Vio]: 0.004819, [OBJECTIVE]: 0.895718
2024-12-05 17:53:22,593:INFO:[TRAIN-ITER]: 200, [Time-Elapsed]: 10.306717, [Total-Loss]: 0.894962
2024-12-05 17:53:22,596:INFO:[CONSTR-Vio]: 0.004813, [OBJECTIVE]: 0.894957
2024-12-05 17:53:26,874:INFO:[TRAIN-ITER]: 300, [Time-Elapsed]: 14.587445, [Total-Loss]: 0.891707
2024-12-05 17:53:26,875:INFO:[CONSTR-Vio]: 0.004768, [OBJECTIVE]: 0.891703
2024-12-05 17:53:31,336:INFO:[TRAIN-ITER]: 400, [Time-Elapsed]: 19.049638, [Total-Loss]: 0.881309
2024-12-05 17:53:31,339:INFO:[CONSTR-Vio]: 0.004785, [OBJECTIVE]: 0.881304
2024-12-05 17:53:35,656:INFO:[TRAIN-ITER]: 500, [Time-Elapsed]: 23.369539, [Total-Loss]: 0.897606
2024-

[[0.00106605 0.00107141 0.0010512 ]
 [0.0010532  0.00106059 0.00105216]
 [0.00104644 0.00104299 0.00104381]]


2024-12-05 17:54:03,921:INFO:	[VAL-ITER]: 1000, [LOSS]: 0.889024, [Constr-vio]: 0.004837, [Objective]: 0.889019
2024-12-05 17:54:08,413:INFO:[TRAIN-ITER]: 1100, [Time-Elapsed]: 56.126890, [Total-Loss]: 0.909126
2024-12-05 17:54:08,415:INFO:[CONSTR-Vio]: 0.005652, [OBJECTIVE]: 0.909120
2024-12-05 17:54:12,792:INFO:[TRAIN-ITER]: 1200, [Time-Elapsed]: 60.506017, [Total-Loss]: 0.930179
2024-12-05 17:54:12,793:INFO:[CONSTR-Vio]: 0.005614, [OBJECTIVE]: 0.930173
2024-12-05 17:54:17,110:INFO:[TRAIN-ITER]: 1300, [Time-Elapsed]: 64.824557, [Total-Loss]: 0.939986
2024-12-05 17:54:17,111:INFO:[CONSTR-Vio]: 0.005609, [OBJECTIVE]: 0.939980
2024-12-05 17:54:21,574:INFO:[TRAIN-ITER]: 1400, [Time-Elapsed]: 69.287674, [Total-Loss]: 0.892613
2024-12-05 17:54:21,577:INFO:[CONSTR-Vio]: 0.005718, [OBJECTIVE]: 0.892606


KeyboardInterrupt: 

In [11]:
def plot_matching(p, q, match):
    output_matrix = match.squeeze().detach().numpy()

    annotations = np.empty_like(output_matrix, dtype=object)
    for i in range(output_matrix.shape[0]):
        for j in range(output_matrix.shape[1]):
            annotations[i, j] = f'{output_matrix[i, j]:.2e}\n[{p[i, j]}, {q[j, i]}]'

    # ヒートマップのプロット
    plt.figure(figsize=(8, 6))
    sns.heatmap(output_matrix, annot=annotations, fmt='', cmap='Blues', cbar=True)
    plt.title("Agent Relationship Heatmap with Vector Details")
    plt.xlabel("Agent")
    plt.ylabel("Agent")
    plt.show()

In [12]:
p = torch.tensor([[1, 0.0, 0], [0, 1, 0], [0, 0, 1]])
q = torch.tensor([[1, 0.0, 0], [0, 1, 0], [0, 0, 1]])

output = model(p, q)

RuntimeError: Tensor for argument input is on cpu but expected on mps

In [None]:
plot_matching(p, q, output)