Paper: https://arxiv.org/abs/2102.07831


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
rc('animation', html='jshtml')

import torch

In [None]:
N = 10
s = np.random.randn(N)
s

In [None]:
# Target positions 0..N-1, and for each position the index of the score that
# belongs there when the scores are ordered from largest to smallest.
original_idx = np.arange(N)
sorted_idx = np.flip(np.argsort(s))  # ascending argsort, reversed
print(sorted_idx)
list(zip(original_idx, sorted_idx))

In [None]:
# P_sort = np.zeros((N, N))
P_sort = np.arange(N*N).reshape(N, N)
P_sort

In [None]:
P_sort[original_idx, sorted_idx]

In [None]:
# Hard permutation matrix: row i is the one-hot vector for column
# sorted_idx[i], so P_sort @ s lists the scores in descending order.
P_sort = np.eye(N)[sorted_idx]
P_sort

In [None]:
# Side-by-side check: fancy-indexed sort vs. permutation-matrix product.
list(zip(s[sorted_idx], P_sort @ s))

In [None]:
# As[j, k] = |s[j] - s[k]|: symmetric matrix of pairwise absolute differences.
As = np.abs(np.subtract.outer(s, s))
As

In [None]:
plt.imshow(As, cmap="Reds")
plt.colorbar()

In [None]:
As.dot(np.ones(N))

In [None]:
As.dot(np.ones(N))[None, :]

In [None]:
from scipy.special import softmax

In [None]:
# First term of the relaxed-sort logits: (N + 1 - 2*i) * s[j] with a 1-based
# row index i, hence the `+ 1` on the 0-based `original_idx` (this matches the
# `(i+1)` used when P_sort_hat is built).
(N + 1 - 2*(original_idx[:, None] + 1))*s[None, :]

In [None]:
tau = 1e-10

# Term shared by both constructions: for each score, the sum of its absolute
# differences to every score.
abs_diff_sums = As.dot(np.ones(N))

# Construction 1: fill the relaxed permutation matrix one row at a time.
P_sort_hat_iter = np.zeros((N, N))
for row in range(N):
  coeff = N + 1 - 2*(row + 1)
  P_sort_hat_iter[row, :] = softmax((coeff*s - abs_diff_sums) / tau)

# Construction 2: the same computation, broadcast over all rows at once.
row_coeffs = N + 1 - 2*(original_idx[:, None] + 1)
P_sort_hat = softmax((row_coeffs*s - abs_diff_sums) / tau, axis=1)
P_sort_hat, P_sort_hat_iter

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(15, 5))

# Exact permutation matrix next to the two relaxed constructions.
panels = {
    "P_sort": P_sort,
    "P_sort_hat_iter": P_sort_hat_iter,
    "P_sort_hat": P_sort_hat,
}
for axis, (label, matrix) in zip(ax, panels.items()):
  im = axis.imshow(matrix, cmap="Reds")
  axis.set_title(label)
# One colorbar for the whole row, driven by the last image drawn.
fig.colorbar(im, ax=ax[:], shrink=0.95, location='bottom')

# Optimize Scores via PyTorch

In [None]:
# Differentiable (PyTorch) version of the relaxed permutation built above:
#   logits[i, j] = ((N + 1 - 2*(i + 1)) * s[j] - sum_k |s[j] - s[k]|) / tau
s_tc = torch.tensor(s, requires_grad=True)
print(s_tc)

As_tc = (s_tc[:, None] - s_tc[None, :]).abs()
print(As_tc.shape)

tau = 1.0

# Multiply by a 1-D ones vector so the result has shape (N,) and broadcasts
# across columns, exactly like the NumPy `As.dot(np.ones(N))`. An (N, 1)
# column would instead shift every row by a constant, which the row-wise
# softmax / cross-entropy cancels, silently dropping this term.
P_sort_hat_logits_tc = (
    (N + 1 - 2*(torch.tensor(original_idx[:, None]) + 1))*s_tc - As_tc @ torch.ones(N, dtype=As_tc.dtype)
) / tau

P_sort_hat_tc = torch.nn.functional.softmax(P_sort_hat_logits_tc, dim=1)
print(P_sort_hat_tc.shape)

# Row i is scored against class sorted_idx[i], the column P_sort marks.
# `.copy()` because sorted_idx is a reversed view; presumably torch cannot
# wrap NumPy arrays with negative strides.
loss = torch.nn.functional.cross_entropy(P_sort_hat_logits_tc, torch.tensor(sorted_idx.copy()))
print(loss)


In [None]:
loss.backward()

In [None]:
s_tc, s_tc.grad

In [None]:
fig, ax = plt.subplots(2, 1)
# Top strip: the scores; bottom strip: their gradient. Each gets its own
# colorbar underneath.
for axis, values in zip(ax, (s_tc, s_tc.grad)):
  im = axis.imshow(values.detach().numpy().reshape(1, -1))
  fig.colorbar(im, ax=axis, location="bottom")


In [None]:
def get_loss(s_tc, sorted_idx, tau = 1.0):
  """Cross-entropy between the relaxed sort matrix of `s_tc` and a target permutation.

  Row i (0-based) of the logits is
  ``((n + 1 - 2*(i + 1)) * s - A @ 1) / tau`` with ``A[j, k] = |s[j] - s[k]|``;
  the row-wise softmax of these logits is the relaxed permutation matrix.

  Args:
    s_tc: 1-D floating-point tensor of n scores.
    sorted_idx: length-n integer array-like; entry i is the index of the score
      that should end up at position i.
    tau: softmax temperature (must be non-zero; smaller is sharper).

  Returns:
    Scalar tensor: the cross-entropy of the logits against `sorted_idx`,
    averaged over the n rows.
  """
  # Derive the size from the input instead of relying on notebook globals.
  n = s_tc.shape[0]

  abs_diffs = (s_tc[:, None] - s_tc[None, :]).abs()
  # Shape (n,), indexed by score j. It must broadcast across columns: an (n, 1)
  # column would add a per-row constant that softmax cancels, removing the term.
  abs_diff_sums = abs_diffs.sum(dim=1)

  row_coeffs = n + 1 - 2*torch.arange(1, n + 1, dtype=s_tc.dtype, device=s_tc.device)
  logits = (row_coeffs[:, None]*s_tc[None, :] - abs_diff_sums[None, :]) / tau

  # Force a contiguous int64 target: callers pass reversed NumPy views
  # (negative strides), and cross_entropy needs long class indices.
  target = torch.as_tensor(
      np.ascontiguousarray(sorted_idx), dtype=torch.long, device=s_tc.device
  )
  return torch.nn.functional.cross_entropy(logits, target)

In [None]:
from tqdm.auto import tqdm, trange

In [None]:
losses = []
s_tc_vals = []
s_tc_grad_vals = []

s_tc = torch.tensor(s, requires_grad=True)
# Random target permutation: the scores are trained so that position i
# selects score sorted_idx_label[i].
sorted_idx_label = np.random.permutation(N)
print(f"{sorted_idx_label=}")


optimizer = torch.optim.Adam([s_tc])

# Snapshot the starting scores. `.numpy()` shares memory with the tensor and
# the optimizer updates it in place, so without `.copy()` this entry would
# track the latest values instead of the initial ones.
s_tc_vals.append(s_tc.detach().numpy().copy())

for i in trange(5000):
    optimizer.zero_grad()
    loss = get_loss(s_tc, sorted_idx_label, tau = 1e-3)
    loss.backward()
    optimizer.step()
    # Record the loss, the updated scores, and the gradient used for the step.
    losses.append(loss.item())
    s_tc_vals.append(s_tc.detach().numpy().copy())
    s_tc_grad_vals.append(s_tc.grad.detach().numpy().copy())

In [None]:
plt.plot(losses)

In [None]:
# fig = plt.figure(figsize=(15, 15))
plt.imshow(np.stack(s_tc_vals, axis=0), aspect='auto')
plt.title(str(sorted_idx_label))
plt.colorbar()

In [None]:
# fig = plt.figure(figsize=(15, 15))
plt.imshow(np.stack(s_tc_grad_vals, axis=0), aspect='auto', cmap="bwr")
plt.title(str(sorted_idx_label))
plt.colorbar()

In [None]:
def optimize_scores(s_tc, sorted_idx_label, tau=1.0, n_steps=5000):
  """Fit `s_tc` with Adam so its relaxed sort matches `sorted_idx_label`.

  Args:
    s_tc: 1-D leaf tensor with ``requires_grad=True``; updated in place.
    sorted_idx_label: target permutation passed through to `get_loss`.
    tau: softmax temperature passed through to `get_loss`.
    n_steps: number of optimizer steps (default 5000).

  Returns:
    Tuple ``(losses, s_tc_vals, s_tc_grad_vals)``:
      losses: list of ``n_steps`` floats, the loss before each step.
      s_tc_vals: array of shape (n_steps + 1, len(s_tc)); row 0 is the
        initial scores, row k the scores after step k.
      s_tc_grad_vals: array of shape (n_steps, len(s_tc)) holding the
        gradient used at each step.
  """
  optimizer = torch.optim.Adam([s_tc])

  losses = []
  # `.numpy()` shares memory with the tensor that Adam updates in place, so
  # the initial snapshot must be copied or it would mirror the final scores.
  s_tc_vals = [s_tc.detach().numpy().copy()]
  s_tc_grad_vals = []

  for _ in trange(n_steps):
    optimizer.zero_grad()
    loss = get_loss(s_tc, sorted_idx_label, tau=tau)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    s_tc_vals.append(s_tc.detach().numpy().copy())
    s_tc_grad_vals.append(s_tc.grad.detach().numpy().copy())

  s_tc_vals = np.stack(s_tc_vals, axis=0)
  if s_tc_grad_vals:
    s_tc_grad_vals = np.stack(s_tc_grad_vals, axis=0)
  else:
    # np.stack rejects an empty list; keep the documented 2-D shape.
    s_tc_grad_vals = np.empty((0,) + tuple(s_tc.shape))

  return losses, s_tc_vals, s_tc_grad_vals

In [None]:
taus = [1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0]
sorted_idx_label = np.random.permutation(N)
print(f"{sorted_idx_label=}")

# One entry per temperature: (tau, losses, score history, gradient history).
tau_stats = []
for tau in tqdm(taus):
  print(tau)
  # Restart from the same initial scores so the runs differ only in tau.
  s_tc = torch.tensor(s, requires_grad=True)
  run_losses, score_hist, grad_hist = optimize_scores(s_tc, sorted_idx_label, tau=tau)
  tau_stats.append((tau, run_losses, score_hist, grad_hist))

In [None]:
sorted_idx_label

In [None]:
# Overlay the loss curve of every temperature run and print its extremes.
for i, tau in enumerate(taus):
  # tau_stats[i] is (tau, losses, ...); index 1 is the per-step loss list.
  plt.plot(tau_stats[i][1], label=f"{tau=:.3g}")
  # The `{expr=}` form prints the expression text itself, so the output
  # depends on these exact names.
  print(f"{tau=:>5.5g}, {min(tau_stats[i][1])=:.5g}, {max(tau_stats[i][1])=:.5g}")

# Log axis so curves from very different temperatures share one plot.
plt.yscale("log")
plt.legend()

In [None]:
fig, ax = plt.subplots(2, 3, figsize=(20, 10))
# Global range of the recorded scores over all runs. It is not passed to
# imshow, so every panel autoscales and carries its own colorbar.
vmin = min(stats[2].min() for stats in tau_stats)
vmax = max(stats[2].max() for stats in tau_stats)
for i, (tau, axi) in enumerate(zip(taus, ax.flat)):
  # Index 2 of each entry is the score history (steps x scores).
  im = axi.imshow(tau_stats[i][2], aspect='auto')
  axi.set_title(f"{tau=:.3g}")
  fig.colorbar(im, ax=axi, location="right")
fig.suptitle(f"{sorted_idx_label=}")

In [None]:
fig, ax = plt.subplots(2, 3, figsize=(20, 10))
# Global range of the recorded gradients over all runs. It is not passed to
# imshow, so every panel autoscales and carries its own colorbar.
vmin = min(stats[3].min() for stats in tau_stats)
vmax = max(stats[3].max() for stats in tau_stats)
for i, (tau, axi) in enumerate(zip(taus, ax.flat)):
  # Index 3 of each entry is the gradient history (steps x scores).
  im = axi.imshow(tau_stats[i][3], aspect='auto')
  axi.set_title(f"{tau=:.3g}")
  fig.colorbar(im, ax=axi, location="right")
fig.suptitle(f"{sorted_idx_label=}")

# Effect of the temperature tau (static comparison)

In [None]:
taus = [1e-6, 1e-3, 0.1, 1]
fig, ax = plt.subplots(1, 1 + len(taus), figsize=(5*(1 + len(taus)), 5))

# Leftmost panel: the exact permutation matrix, which drives the colorbar.
im = ax[0].imshow(P_sort, cmap="Reds")
ax[0].set_title("P_sort")
fig.colorbar(im, ax=ax[:], shrink=0.95, location='bottom')

# The logits do not depend on tau; only the division changes per panel.
logits = (N + 1 - 2*(original_idx[:, None] + 1))*s - As.dot(np.ones(N))
for axis, tau in zip(ax[1:], taus):
  P_sort_hat_t = softmax(logits / tau, axis=1)
  im = axis.imshow(P_sort_hat_t, cmap="Greys")
  axis.set_title(f"P_sort_hat_t[{tau=}]")


## Animate

In [None]:
from matplotlib.animation import FuncAnimation, ArtistAnimation

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(5*2, 5))

# Left panel: exact permutation matrix; its image drives the shared colorbar.
im = ax[0].imshow(P_sort, cmap="Reds")
ax[0].set_title("$P_{sort}$")
fig.colorbar(im, ax=ax[:], shrink=0.95, location='bottom')

# Temperatures to animate, from very sharp to smooth.
taus = np.logspace(-10, 1.01, 20)

# The logits do not depend on tau; each frame only rescales them.
logits = (N + 1 - 2*(original_idx[:, None] + 1))*s - As.dot(np.ones(N))
l = "$P_{sort}^{approx}$"


def soft_perm(tau):
  """Relaxed permutation matrix (row-wise softmax) at temperature `tau`."""
  return softmax(logits / tau, axis=1)


# Start from the first animated temperature instead of whatever `tau` an
# earlier cell left behind. The colour range is pinned to [0, 1] because
# imshow fixes its normalization from the first array and set_data() does not
# rescale; this also keeps the panel in step with the shared colorbar.
tau = taus[0]
im = ax[1].imshow(soft_perm(tau), cmap="Reds", vmin=0, vmax=1, animated=True)
ax_title = ax[1].set_title(f"{l} [{tau=:.4g}]")


def update(frame):
  """Draw frame number `frame`: the relaxed matrix at taus[frame]."""
  tau = taus[frame]
  im.set_data(soft_perm(tau))
  ax_title.set_text(f"{l} [{tau=:.4g}]")
  return im, ax_title


def init():
  """Initial state for blitting: identical to the first frame."""
  return update(0)


ani = FuncAnimation(
  fig, update,
  init_func=init,
  frames=range(len(taus)),
  interval=100,
  blit=True,
  repeat_delay=5000
)

ani.save("movie.mp4")

ani

In [None]:
def annotate_axes(ax, text, fontsize=18):
    """Write `text` in dark grey, centred in the middle of `ax`."""
    label_style = dict(ha="center", va="center", fontsize=fontsize, color="darkgrey")
    # (0.5, 0.5) in the axes' own coordinate system is the centre of the axes.
    ax.text(0.5, 0.5, text, transform=ax.transAxes, **label_style)


# Nested mosaic: the upper-right slot is itself split into two stacked axes.
inner = [['innerA'], ['innerB']]
outer = [
    ['upper left', inner],
    ['lower left', 'lower right'],
]

fig, axd = plt.subplot_mosaic(outer, layout="constrained")
# Label every axes with the key it is stored under.
for name, ax in axd.items():
    annotate_axes(ax, f'axd[{name!r}]')

In [None]:
# Repeating a label over several rows makes that axes span those rows:
# A/B take three rows, C/D two, and E one full-width row.
outer = [["A", "B"]]*3 + [["C", "D"]]*2 + [["E", "E"]]

fig, axd = plt.subplot_mosaic(outer, layout="constrained")
# Label every axes with the key it is stored under.
for name, ax in axd.items():
    annotate_axes(ax, f'axd[{name!r}]')

In [None]:
np.hstack([np.logspace(-4, -1, 10), np.logspace(-1, 2, 40)])

In [None]:
10**1.10

In [None]:
# Layout: A/B = exact vs. relaxed permutation matrix (three rows tall),
# O = original scores (full width), C/D = scores sorted by A and by B.
outer = [
    ["A", "B"],
    ["A", "B"],
    ["A", "B"],
    ["O", "O"],
    ["C", "D"],
]

fig, ax = plt.subplot_mosaic(outer, layout="constrained", figsize=(12, 10))

# Fixed colour ranges so every image agrees with the colorbar it shares.
# imshow freezes its normalization from the first array it is given and
# set_data() does not rescale, so autoscaled animated panels would drift away
# from the colorbars. Matrix entries lie in [0, 1]; each row of a (soft)
# permutation is non-negative and sums to 1, so its product with s stays
# inside [s.min(), s.max()].
score_range = dict(vmin=s.min(), vmax=s.max())

k = "A"
p = P_sort
l = "$P_{sort}$"
im = ax[k].imshow(p, cmap="Reds", vmin=0, vmax=1)
ax[k].set_title(l)
fig.colorbar(im, ax=[ax["A"], ax["B"]], shrink=0.99, location='bottom')

k = "O"
ax[k].imshow(s[None, :], cmap="Greys", **score_range)
ax[k].set_title("Original Scores ($s$)")
ax[k].set(yticklabels=[])

k = "C"
im_s = ax[k].imshow(p.dot(s)[None, :], cmap="Greys", **score_range)
ax[k].set_title(f"Sorted Scores ({l}$.dot(s)$)")
ax[k].set_axis_off()
fig.colorbar(im_s, ax=[ax["C"], ax["D"]], shrink=0.99, location='bottom')

# Animate from the smoothest temperature (100) down to the sharpest (1e-4).
taus = np.hstack([np.logspace(-4, -1, 10), np.logspace(-1, 2, 40)])[::-1]

# The logits do not depend on tau; each frame only rescales them.
logits = (N + 1 - 2*(original_idx[:, None] + 1))*s - As.dot(np.ones(N))


def soft_perm(tau):
  """Relaxed permutation matrix (row-wise softmax) at temperature `tau`."""
  return softmax(logits / tau, axis=1)


tau = 0.5  # temperature shown in the static figure before the animation runs

k = "B"
p = soft_perm(tau)
l = "$P_{sort}^{approx}$"
im = ax[k].imshow(p, cmap="Reds", vmin=0, vmax=1, animated=True)
ax_title = ax[k].set_title(f"{l} [{tau=}]")

k = "D"
im_s = ax[k].imshow(p.dot(s)[None, :], cmap="Greys", animated=True, **score_range)
ax[k].set_axis_off()
ax[k].set_title(f"Sorted Scores ({l}$.dot(s)$)")

fig.suptitle("Soft Sort")


def update(frame):
  """Draw frame `frame`: relaxed matrix and soft-sorted scores at taus[frame]."""
  tau = taus[frame]
  p = soft_perm(tau)
  im.set_data(p)
  im_s.set_data(p.dot(s)[None, :])
  ax_title.set_text(f"{l} [{tau=:.4g}]")
  return im, ax_title, im_s


def init():
  """Initial state for blitting: identical to the first frame."""
  return update(0)


ani = FuncAnimation(
  fig, update,
  init_func=init,
  frames=range(len(taus)),
  interval=100,
  blit=True,
  repeat_delay=5000
)

ani


In [None]:
ani.save("movie.mp4", dpi=100)

In [None]:
# Static version of the mosaic figure at a single temperature.
# Layout: A/B = exact vs. relaxed permutation matrix (three rows tall),
# O = original scores (full width), C/D = scores sorted by A and by B.
outer = [
    ["A", "B"],
    ["A", "B"],
    ["A", "B"],
    ["O", "O"],
    ["C", "D"],
]

fig, ax = plt.subplot_mosaic(outer, layout="constrained", figsize=(12, 10))

# Each colorbar is shared by two panels, so both panels must use the same
# colour range; left to autoscale, B and D would not match the bars drawn
# from A and C. Matrix entries lie in [0, 1]; each row of a (soft)
# permutation is non-negative and sums to 1, so its product with s stays
# inside [s.min(), s.max()].
score_range = dict(vmin=s.min(), vmax=s.max())

k = "A"
p = P_sort
l = "$P_{sort}$"
im = ax[k].imshow(p, cmap="Reds", vmin=0, vmax=1)
ax[k].set_title(l)
fig.colorbar(im, ax=[ax["A"], ax["B"]], shrink=0.99, location='bottom')

k = "O"
ax[k].imshow(s[None, :], cmap="Greys", **score_range)
ax[k].set_title("Original Scores ($s$)")
ax[k].set(yticklabels=[])

k = "C"
im_s = ax[k].imshow(p.dot(s)[None, :], cmap="Greys", **score_range)
ax[k].set_title(f"{l}$.dot(s)$")
ax[k].set_axis_off()
fig.colorbar(im_s, ax=[ax["C"], ax["D"]], shrink=0.99, location='bottom')

tau = 0.5

k = "B"
p = softmax(((N + 1 - 2*(original_idx[:, None] + 1))*s - As.dot(np.ones(N))) / tau, axis=1)
l = "$P_{sort}^{approx}$"
im = ax[k].imshow(p, cmap="Reds", vmin=0, vmax=1, animated=True)
ax_title = ax[k].set_title(f"{l} [{tau=}]")

k = "D"
im_s = ax[k].imshow(p.dot(s)[None, :], cmap="Greys", **score_range)
ax[k].set_axis_off()
ax[k].set_title(f"{l}$.dot(s)$")


fig.suptitle("Soft Sort")


In [None]:

taus = [1e-6, 1e-3, 0.1, 1]
fig, ax = plt.subplots(1, 2, figsize=(5*2, 5))

# Left panel: exact permutation matrix; its image drives the shared colorbar.
im = ax[0].imshow(P_sort, cmap="Reds")
ax[0].set_title("P_sort")
fig.colorbar(im, ax=ax[:], shrink=0.95, location='bottom')

i = 1
l = "P_sort[approx]"

# The logits do not depend on tau; each frame only rescales them.
logits = (N + 1 - 2*(original_idx[:, None] + 1))*s - As.dot(np.ones(N))

# One [image, caption] artist pair per temperature; ArtistAnimation shows
# exactly one pair per frame.
ims = []
for frame, tau in enumerate(taus):
  p = softmax(logits / tau, axis=1)
  # Pin the colour range so every frame matches the shared 0..1 colorbar
  # instead of being autoscaled on its own.
  im = ax[i].imshow(p, cmap="Reds", vmin=0, vmax=1, animated=True)
  # A new Text artist per frame: Axes.set_title returns the same Text object
  # on every call, so reusing it would make every frame show the last tau.
  frame_title = ax[i].text(
      0.5, 1.02, f"{l} [{tau=}]",
      transform=ax[i].transAxes, ha="center", va="bottom", fontsize="large",
  )
  if frame == 0:
    ax[i].imshow(p, cmap="Reds", vmin=0, vmax=1)  # show an initial one first
  ims.append([im, frame_title])


ani = ArtistAnimation(fig, ims, interval=1000, blit=True,
                                repeat_delay=5000)

ani.save("movie.mp4")

ani