In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import softmax, logsumexp
import seaborn as sns

In [None]:
def sample_rankings(log_scores, n_samples, cutoff=None, prob_per_rank=False, rng=None):
  """Sample rankings from a Plackett-Luce model by sequential softmax sampling.

  At every rank position, each still-unplaced document is picked with
  probability softmax(remaining log_scores). The picked document is then
  masked out (its score set to -inf) before the next position is sampled.
  All `n_samples` rankings are drawn in parallel.

  Args:
    log_scores: 1-D array-like of `n_docs` unnormalised log-scores
      (converted to float64; integer input is accepted).
    n_samples: number of independent rankings to draw.
    cutoff: if truthy, only the top `min(n_docs, cutoff)` positions are
      sampled; otherwise full rankings of length `n_docs` are drawn.
    prob_per_rank: if True, also return `rank_prob_matrix`.
    rng: object with a `uniform(size=...)` method, e.g. a
      `np.random.Generator`. Defaults to the global `np.random` state, so
      `np.random.seed(...)` still controls the draws.

  Returns:
    rankings: (n_samples, ranking_len) int32, document index placed at each rank.
    inv_rankings: (n_samples, n_docs) int32, rank of each document. Documents
      left unplaced (only possible with `cutoff`) get `ranking_len`.
    rankings_prob: (n_samples, ranking_len) float64, probability with which
      the chosen document was picked at that rank.
    rank_prob_matrix: only if `prob_per_rank`; (ranking_len, n_docs) float64.
      Row i is the mean over samples of the selection probabilities at rank i.
  """
  if rng is None:
    rng = np.random
  log_scores = np.asarray(log_scores, dtype=np.float64)
  n_docs = log_scores.shape[0]
  ind = np.arange(n_samples)

  ranking_len = min(n_docs, cutoff) if cutoff else n_docs

  if prob_per_rank:
    rank_prob_matrix = np.empty((ranking_len, n_docs), dtype=np.float64)

  # One float working copy of the scores per sample. Placed documents are
  # masked with -inf in their own row only.
  scores = np.tile(log_scores[None, :], (n_samples, 1))
  rankings = np.empty((n_samples, ranking_len), dtype=np.int32)
  inv_rankings = np.empty((n_samples, n_docs), dtype=np.int32)
  rankings_prob = np.empty((n_samples, ranking_len), dtype=np.float64)

  if cutoff:
    # Sentinel rank for documents that never get placed.
    inv_rankings[:] = ranking_len

  for i in range(ranking_len):
    # Shift each row so its maximum is 18 before exponentiating. Softmax is
    # shift-invariant, and -inf entries stay -inf (probability 0).
    scores += 18 - np.amax(scores, axis=1)[:, None]
    log_denom = np.log(np.sum(np.exp(scores), axis=1))
    probs = np.exp(scores - log_denom[:, None])
    if prob_per_rank:
      rank_prob_matrix[i, :] = np.mean(probs, axis=0)

    # Inverse-CDF sampling. Dividing by the last cumulative value makes it
    # exactly 1.0, so a uniform draw in [0, 1) always lands on a document with
    # non-zero probability. Without this, rounding can leave the total just
    # below 1 and yield the out-of-range index n_docs.
    cumprobs = np.cumsum(probs, axis=1)
    cumprobs = cumprobs / cumprobs[:, -1:]
    random_values = rng.uniform(size=n_samples)
    sampled_ind = np.sum(random_values[:, None] >= cumprobs, axis=1)

    rankings[:, i] = sampled_ind
    inv_rankings[ind, sampled_ind] = i
    rankings_prob[:, i] = probs[ind, sampled_ind]
    # Remove the placed document from later positions of this sample.
    # (-np.inf rather than np.NINF, which NumPy 2.0 removed.)
    scores[ind, sampled_ind] = -np.inf

  if prob_per_rank:
    return rankings, inv_rankings, rankings_prob, rank_prob_matrix
  return rankings, inv_rankings, rankings_prob

In [None]:
# Toy problem: N documents with standard-normal log-scores.
# Seed the global NumPy RNG so Restart & Run All reproduces the same scores and
# the same sampled rankings. `sample_rankings` and the manual sampling cells
# below draw from the global `np.random` functions.
SEED = 0
np.random.seed(SEED)
N = 10
n_samples = 3
log_scores = np.random.randn(N)
log_scores

In [None]:
# Draw top-5 rankings and also keep the per-rank mean selection probabilities.
(rankings, inv_rankings, rankings_prob, rank_prob_matrix) = sample_rankings(
    log_scores, n_samples, cutoff=5, prob_per_rank=True
)
(rankings, inv_rankings, rankings_prob, rank_prob_matrix)

In [None]:
# Shapes of the four outputs; the ranking length is min(N, cutoff).
tuple(arr.shape for arr in (rankings, inv_rankings, rankings_prob, rank_prob_matrix))

In [None]:
# First-rank selection probabilities: softmax of the scores, shifted by their
# max before exponentiating so exp() cannot overflow.
probs = np.exp(np.subtract(log_scores, np.max(log_scores)))
probs /= np.sum(probs)
probs

In [None]:
# Second-largest first-rank probability (same value as the sorted array's [-2]).
np.partition(probs, -2)[-2]

In [None]:
# Red: first-rank probabilities per document, with dotted guides at the three
# largest values. Dashed: probability each sampled ranking assigned to its pick
# at every rank.
# NOTE(review): the red curve is indexed by document and the sample curves by
# rank position, yet they share this x-axis.
fig, ax = plt.subplots()
ax.plot(probs, marker="o", color="r", label="first-rank probs")
top3 = np.sort(probs)[::-1][:3]  # sort once, outside the loop
for p in top3:
  ax.axhline(y=p, linestyle=":", lw=0.5, color="0.5")
for k, sample_probs in enumerate(rankings_prob):
  ax.plot(sample_probs, linestyle="--", marker="x", label=f"sample {k}")
ax.set(xlabel="document index / rank position", ylabel="probability",
       title="First-rank probabilities vs. sampled pick probabilities")
ax.legend();

In [None]:
# Sanity check for the `idx[-(i+1):]` slicing used below: a slice that starts
# at -2 keeps the last two elements.
np.arange(5)[slice(-2, None)]

In [None]:
# Red: first-rank probabilities. For i = 0, 1, 2, dashed curve i renormalises
# them after zeroing the i+1 most likely documents. Under the sequential-softmax
# model, that would be the next position's distribution had those documents been
# placed first. Dotted guides mark the removed documents' probabilities.
# Dash-dot: per-rank pick probabilities of the sampled rankings (indexed by rank
# position, not by document).
fig, ax = plt.subplots()
ax.plot(probs, marker="o", color="r")
idx = np.argsort(probs)  # ascending, so idx[-1] is the most likely document
for i in range(3):  # top 3 probs
  ax.axhline(y=probs[idx[-(i + 1)]], linestyle=":", lw=0.5, color="0.5")
  remaining = probs.copy()
  remaining[idx[-(i + 1):]] = 0
  ax.plot(remaining / remaining.sum(), linestyle="--", marker="o", label=f"{i=}")
ax.legend()

for sample_probs in rankings_prob:
  ax.plot(sample_probs, linestyle="-.", marker="x")
ax.set(xlabel="document index / rank position", ylabel="probability",
       title="Renormalised probabilities after removing the top documents")

In [None]:
# Row i: mean over samples of each document's selection probability at rank i.
# Rows should each sum to ~1. seaborn draws its own colorbar, so no
# plt.colorbar() is needed.
ax = sns.heatmap(rank_prob_matrix, square=True, annot=True, fmt=".2f")
ax.set(xlabel="document index", ylabel="rank position",
       title="Mean selection probability per rank");

In [None]:
# Same matrix with each row divided by its maximum, so the most likely
# document at every rank shows as 1.00.
row_scaled = rank_prob_matrix / rank_prob_matrix.max(axis=1, keepdims=True)
ax = sns.heatmap(row_scaled, square=True, annot=True, fmt=".2f")
ax.set(xlabel="document index", ylabel="rank position",
       title="Selection probability per rank (row-max normalised)");

In [None]:
# Documents ordered by score, ascending (the last entry is the most likely
# first pick). `logits` was never defined in this notebook; the scores live in
# `log_scores`.
log_scores.argsort()

In [None]:
# Each row is a mean of probability vectors, so every row sum should be ~1.
np.sum(rank_prob_matrix, axis=1, keepdims=True)

In [None]:
# One copy of the scores per sample, stacked as rows: shape (n_samples, N).
log_scores_tiled = np.repeat(log_scores[np.newaxis, :], n_samples, axis=0)
log_scores_tiled

In [None]:
# Shift each row so its maximum becomes 18 (same shift as inside `sample_rankings`).
np.subtract(log_scores_tiled + 18, log_scores_tiled.max(axis=1, keepdims=True))

In [None]:
# Log-normaliser of each shifted row, computed naively as log(sum(exp(.))).
np.log(
    np.exp(
        np.subtract(log_scores_tiled + 18, log_scores_tiled.max(axis=1, keepdims=True))
    ).sum(axis=1, keepdims=True)
)

In [None]:
# The same per-row log-normaliser via scipy's logsumexp.
logsumexp(
    np.subtract(log_scores_tiled + 18, log_scores_tiled.max(axis=1, keepdims=True)),
    axis=1,
    keepdims=True,
)

In [None]:
# Softmax of each shifted row written out as exp(x - logsumexp(x)).
(lambda shifted: np.exp(shifted - logsumexp(shifted, axis=1, keepdims=True)))(
    log_scores_tiled + 18 - np.amax(log_scores_tiled, axis=1, keepdims=True)
)

In [None]:
# scipy's softmax over each shifted row; mathematically equal to exp(x - logsumexp(x)).
softmax(np.subtract(log_scores_tiled + 18, log_scores_tiled.max(axis=1, keepdims=True)), axis=1)

In [None]:
# Replay the first three steps of the sampler by hand on `log_scores_tiled`,
# drawing one heatmap of the per-sample selection probabilities per step.
# NOTE: mutates `log_scores_tiled` in place (placed documents become -inf).
# Re-running this cell continues from where it stopped; once every entry of a
# row is -inf, the shift produces NaNs.
# Step probabilities are stored in `step_probs` so the global `probs` used by
# the plot cells above is not overwritten.
ind = np.arange(n_samples)
for i in range(3):
  log_scores_tiled += 18 - np.amax(log_scores_tiled, axis=1, keepdims=True)
  step_probs = softmax(log_scores_tiled, axis=1)
  fig, ax = plt.subplots()
  sns.heatmap(step_probs, square=True, annot=True, fmt=".2f", ax=ax)
  ax.set(title=f"{i=}", xlabel="document index", ylabel="sample")
  # Inverse-CDF sampling. Normalising makes the last cumulative value exactly 1,
  # so the sampled index never runs past the last available document.
  cumprobs = np.cumsum(step_probs, axis=1)
  cumprobs = cumprobs / cumprobs[:, -1:]
  random_values = np.random.uniform(size=n_samples)
  sampled_ind = np.sum(random_values[:, None] >= cumprobs, axis=1)
  # -np.inf rather than np.NINF, which NumPy 2.0 removed.
  log_scores_tiled[ind, sampled_ind] = -np.inf

In [None]:
# Softmax of the partially consumed scores: entries already set to -inf by the
# sampling loop get probability 0.
softmax(log_scores_tiled, axis=-1)

In [None]:
# One further sampler step by hand on the partially consumed 2-D
# `log_scores_tiled`; run it after the manual loop above. The 1-D `log_scores`
# has no axis 1, so it is not used here. Results go into new names instead of
# `rankings` / `inv_rankings` / `rankings_prob`, so the output of
# `sample_rankings` is not overwritten.
ind = np.arange(n_samples)
log_scores_tiled += 18 - np.amax(log_scores_tiled, axis=1, keepdims=True)
log_denom = logsumexp(log_scores_tiled, axis=1)
step_probs = np.exp(log_scores_tiled - log_denom[:, None])
# Inverse-CDF sampling with the last cumulative value forced to exactly 1.
cumprobs = np.cumsum(step_probs, axis=1)
cumprobs = cumprobs / cumprobs[:, -1:]
random_values = np.random.uniform(size=n_samples)
sampled_ind = np.sum(random_values[:, None] >= cumprobs, axis=1)
step_pick_prob = step_probs[ind, sampled_ind]
# -np.inf rather than np.NINF, which NumPy 2.0 removed.
log_scores_tiled[ind, sampled_ind] = -np.inf
sampled_ind, step_pick_prob