In [1]:
# Sorting-network construction based on diffsort: https://github.com/Felix-Petersen/diffsort

import math
import os
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch

def odd_even_network(n):
    """Build an odd-even transposition sorting network for ``n`` inputs as NumPy matrices.

    The network has ``n`` layers. Even-indexed layers compare wire pairs
    ``(0, 1), (2, 3), ...``; odd-indexed ("shifted") layers compare ``(1, 2), (3, 4), ...``.
    Every wire left without a partner in a layer is routed through a dummy comparator
    that compares the wire with itself, so each layer covers all ``n`` wires.

    Each layer is a tuple ``(split_a, split_b, combine_min, combine_max)`` of float64
    arrays, with ``k`` the number of comparators (real + dummy) in that layer:

    * ``split_a``, ``split_b`` -- shape ``(k, n)``; row ``c`` is one-hot on the first /
      second input wire of comparator ``c``.
    * ``combine_min``, ``combine_max`` -- shape ``(n, k)``; column ``c`` routes the
      comparator's min / max output back onto its wire. Dummy comparators put 0.5 in
      both, so min + max of (v, v) recombine to v.

    Args:
        n: Number of wires (values to sort). ``n <= 0`` yields an empty list.

    Returns:
        List of ``n`` layer tuples as described above.
    """
    is_even = n % 2 == 0
    layers = []

    for depth in range(n):
        offset = depth % 2  # 1 on "shifted" layers, which start pairing at wire 1

        pairs = [(i, i + 1) for i in range(offset, n - 1, 2)]
        # Wires without a partner in this layer: both ends on shifted even-n layers,
        # otherwise the last wire (unshifted, odd n) or the first wire (shifted, odd n).
        if is_even:
            lonely = [0, n - 1] if offset else []
        else:
            lonely = [0] if offset else [n - 1]

        width = len(pairs) + len(lonely)
        assert width == n // 2 + (offset if is_even else 1)

        split_a, split_b = np.zeros((width, n)), np.zeros((width, n))
        combine_min, combine_max = np.zeros((n, width)), np.zeros((n, width))

        for slot, (lo, hi) in enumerate(pairs):
            split_a[slot, lo] = split_b[slot, hi] = 1
            combine_min[lo, slot] = combine_max[hi, slot] = 1

        # Dummy comparators come after the real ones, in the order listed in `lonely`.
        for slot, wire in enumerate(lonely, start=len(pairs)):
            split_a[slot, wire] = split_b[slot, wire] = 1
            combine_min[wire, slot] = combine_max[wire, slot] = .5

        layers.append((split_a, split_b, combine_min, combine_max))

    return layers


def get_sorting_network(n, device='cpu'):
    """Return the odd-even sorting network for ``n`` inputs as float32 torch tensors.

    Args:
        n: Number of inputs to sort (passed to ``odd_even_network``).
        device: Device every tensor is moved to (default ``'cpu'``).

    Returns:
        One entry per layer; each entry is a list
        ``[split_a, split_b, combine_min, combine_max]`` of float32 tensors.
    """
    network = []
    for layer in odd_even_network(n):
        # from_numpy keeps the float64 data; .float() casts to float32 before the move.
        network.append([torch.from_numpy(matrix).float().to(device) for matrix in layer])
    return network

In [2]:
# Build the relaxed odd-even network for 4 inputs; it has one layer per input.
ls = get_sorting_network(n=4)
len(ls)

4

In [3]:
# Inspect the first layer's four matrices and the (n, k) shape of its combine_min matrix.
(ls[0], ls[0][2].size())

([tensor([[1., 0., 0., 0.],
          [0., 0., 1., 0.]]),
  tensor([[0., 1., 0., 0.],
          [0., 0., 0., 1.]]),
  tensor([[1., 0.],
          [0., 0.],
          [0., 1.],
          [0., 0.]]),
  tensor([[0., 0.],
          [1., 0.],
          [0., 0.],
          [0., 1.]])],
 torch.Size([4, 2]))

In [4]:
# Run a single layer (ls[0]) of the relaxed sorting network by hand on one example.
vectors = torch.tensor([[3, 2, 4, 1]], dtype=torch.float32)  # one row of 4 values
steepness = 10.0  # scale applied to (b - a) inside the sigmoid; larger = closer to a hard swap

x = vectors
batch_size, n_items = x.shape
# X starts as one identity matrix per row of `vectors` and accumulates the soft permutation.
X = torch.eye(n_items, dtype=x.dtype, device=x.device).repeat(batch_size, 1, 1)

split_a, split_b, combine_min, combine_max = (m.type(x.dtype) for m in ls[0])

# Gather the two inputs of every comparator in this layer.
a = x @ split_a.T
b = x @ split_b.T

# PyTorch did not support Half (float16) for sigmoid as of 25 August 2021,
# so float16 inputs are upcast to float32 for the sigmoid only.
new_type = torch.float32 if x.dtype == torch.float16 else x.dtype
# alpha -> 1 when b > a (a is the min), -> 0 when a > b (b is the min).
alpha = torch.sigmoid((b - a).type(new_type) * steepness).type(x.dtype)

# Apply the same soft min/max mixing to the permutation matrix.
aX = X @ split_a.T
bX = X @ split_b.T
alpha_col = alpha.unsqueeze(-2)
w_min = alpha_col * aX + (1 - alpha_col) * bX
w_max = (1 - alpha_col) * aX + alpha_col * bX
X = (w_max @ combine_max.T.unsqueeze(-3)) + (w_min @ combine_min.T.unsqueeze(-3))

# Soft min/max of each comparator, scattered back onto their wires.
soft_min = alpha * a + (1 - alpha) * b
soft_max = (1 - alpha) * a + alpha * b
x = soft_min @ combine_min.T + soft_max @ combine_max.T

In [5]:
# Sanity check: the update rules in the layer cell keep x == vectors @ X (input row
# vector on the LEFT of X). X[0] @ vectors[0] applies the transpose and only agrees
# here because X is symmetric after a single layer of disjoint swaps.
assert torch.allclose(vectors[0] @ X[0], x[0]), "soft permutation does not reproduce x"
vectors[0] @ X[0]

tensor([2.0000, 3.0000, 1.0000, 4.0000])

In [6]:
# Soft permutation matrix after the first layer (one 4x4 matrix for the single input row);
# entries near 0/1 show which wires each comparator swapped.
X

tensor([[[4.5398e-05, 9.9995e-01, 0.0000e+00, 0.0000e+00],
         [9.9995e-01, 4.5398e-05, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 9.3576e-14, 1.0000e+00],
         [0.0000e+00, 0.0000e+00, 1.0000e+00, 9.3576e-14]]])

In [7]:
import pandas as pd

In [8]:
# Load the 2024H2 random query sample. Keep only rows whose `DD種別` (DD type) is
# '正解型' (correct-answer type) and whose `領域` (domain) is '一般' (general).
# Missing values become 'なし' ("none"), and `正解gid` is renamed to `gids`.
# Override the data root with the DD_QUALITY_DIR environment variable. The default is
# the original author's local checkout, so set it when running on another machine.
DD_QUALITY_DIR = Path(os.environ.get(
    "DD_QUALITY_DIR",
    "/Users/ashibaga/ghq/ghe.corp.yahoo.co.jp/exodus/exodus/research/dd_quality/2024H2",
))
queries_path = DD_QUALITY_DIR / "random_202408" / "queries.tsv"

df = (
    pd.read_csv(queries_path, sep="\t")
    .query("DD種別=='正解型'&領域=='一般'")
    .fillna('なし')
    .rename({"正解gid": "gids"}, axis=1)[["query", "gids"]]
    # Constant metadata columns, added without in-place mutation (column order preserved).
    .assign(
        context="なし",
        is_closed="なし",
        google="spot",
        either_is_fine=0,
    )
)

FileNotFoundError: [Errno 2] No such file or directory: '/Users/ashibaga/ghq/ghe.corp.yahoo.co.jp/exodus/exodus/research/dd_quality/2024H2/random_202408/queries.tsv'

In [46]:
# Write the filtered query set for the generalspot_20240807 run.
# Override the data root with the DD_QUALITY_DIR environment variable; the default
# is the original author's local checkout.
DD_QUALITY_DIR = Path(os.environ.get(
    "DD_QUALITY_DIR",
    "/Users/ashibaga/ghq/ghe.corp.yahoo.co.jp/exodus/exodus/research/dd_quality/2024H2",
))
output_path = DD_QUALITY_DIR / "generalspot_20240807" / "queries_r2k.tsv"

# Fail loudly if `df` is not the frame built by the loading cell. For example, it could
# be a stale object from an earlier kernel session; the loading cell's recorded run failed.
missing = {"query", "gids", "context", "is_closed", "google", "either_is_fine"} - set(df.columns)
assert not missing, f"df is missing expected columns: {sorted(missing)}"

# to_csv does not create missing parent directories, so create the run folder first.
output_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(output_path, sep="\t", index=False)