In [1]:
import itertools

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

In [2]:
# Every ordering of the seven symbols "1".."7", as 7-character strings in
# the lexicographic order produced by itertools.permutations.
perms = ["".join(ordering) for ordering in itertools.permutations("1234567")]
# Reverse lookup: permutation string -> its row index in `perms`.
perm2id = dict(zip(perms, range(len(perms))))

In [3]:
# Integer view of the permutations: one row per permutation, one column per
# position, each entry the digit (1..7) found at that position.
perms_arr = np.array([[int(symbol) for symbol in perm] for perm in perms])
perms_arr.shape

(5040, 7)

In [4]:
# One-hot encode every permutation as a 7x7 matrix indexed [symbol, position]:
# np.eye(7)[digit - 1] yields [position, symbol], the transpose swaps the two.
perms_onehot = np.eye(7)[perms_arr - 1].transpose(0, 2, 1)
# Sanity check: the row for symbol "1" marks exactly the positions holding a 1.
symbol_one_rows = perms_onehot[:, 0, :].astype(np.int64)
symbol_one_mask = (perms_arr == 1).astype(np.int64)
assert np.allclose(symbol_one_rows, symbol_one_mask)

print("onehot 1234567:")
print(perms_onehot[perm2id["1234567"]])

print("onehot 5671234:")
print(perms_onehot[perm2id["5671234"]])

# Worked example: slide `right` across a zero-padded `left` and count, for
# each of the 15 horizontal offsets, how many symbols coincide.
print("correlate between 1234567 and 5671234")
left = perms_onehot[perm2id["1234567"]]
right = perms_onehot[perm2id["5671234"]]
padded_left = F.pad(torch.Tensor(left[None, None, :, :]), (7, 7))
right_kernel = torch.Tensor(right[None, None, :, :])
matches = F.conv2d(padded_left, right_kernel, padding="valid").numpy().ravel()
print(matches)
# Required coincidence count per offset for the overlap to be exact
# (-1 can never equal a count, so those offsets are excluded).
must_match_left2right = np.concatenate([np.full(7, -1), np.arange(7, -1, -1)])
must_match_right2left = np.concatenate([np.arange(8), np.full(7, -1)])
# Number of symbols that have to be appended when the overlap sits at each offset.
cost_ifmatch = np.abs(np.arange(-7, 8))
print("cost of 1234567 -> 5671234:", cost_ifmatch[must_match_left2right == matches].min())
print("cost of 5671234 -> 1234567:", cost_ifmatch[must_match_right2left == matches].min())

onehot 1234567:
[[1. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 1.]]
onehot 5671234:
[[0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 1.]
 [1. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0.]]
correlate between 1234567 and 5671234
[0. 0. 0. 0. 4. 0. 0. 0. 0. 0. 0. 3. 0. 0. 0.]
cost of 1234567 -> 5671234: 4
cost of 5671234 -> 1234567: 3


In [5]:
# Correlate every permutation against every other one in a single conv2d call:
# the zero-padded one-hot stack is the input batch and the unpadded stack is
# the bank of kernels, so M[a, b, k] is the number of coinciding symbols when
# permutation b is laid over padded permutation a at horizontal offset k.
onehot_stack = torch.Tensor(perms_onehot[:, None, :, :])
M = F.conv2d(F.pad(onehot_stack, (7, 7)), onehot_stack, padding="valid").squeeze().numpy()

M.shape

(5040, 5040, 15)

In [6]:
# Coincidence count each offset must show for an exact overlap (-1 = offset
# not allowed), and the same pattern relaxed by one mismatching symbol.
must_match_left2right = np.array([-1, -1, -1, -1, -1, -1, -1, 7, 6, 5, 4, 3, 2, 1, 0])
must_match_left2right_wild = np.array([-1, -1, -1, -1, -1, -1, -1, 6, 5, 4, 3, 2, 1, 0, 0])

# Symbols that must be appended when the overlap sits at each offset.
cost_ifmatch = np.array([7, 6, 5, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5, 6, 7])


def _cheapest_offset_cost(required_matches):
    """Per permutation pair, the smallest cost among offsets whose match count equals the requirement."""
    return np.where(M == required_matches, cost_ifmatch, np.inf).min(axis=-1)


# costMat[a, b]: cost of following permutation a with permutation b.
costMat = _cheapest_offset_cost(must_match_left2right).astype(np.int8)
# costMatWild[a, b]: the same cost when one mismatching symbol is tolerated
# (never worse than the exact cost).
costMatWild = np.minimum(costMat, _cheapest_offset_cost(must_match_left2right_wild)).astype(np.int8)


In [7]:
# The seven schedule symbols, in the order that maps onto the digits "1".."7".
# Written as explicit code-point escapes so the literal cannot be corrupted by
# a wrong file encoding; str.maketrans needs exactly one character per digit.
# Code points: 127877, 129334, 129420, 129501, 127876, 127873, 127872.
symbols = "\U0001F385\U0001F936\U0001F98C\U0001F9DD\U0001F384\U0001F381\U0001F380"
# Load the wildcard-free submission and rewrite each schedule string over the
# digit alphabet "1".."7" used by the permutation tables.
submission_in = pd.read_csv("submission_tsp_6_perm_rebalance_no_wildcards_2218_2079_1947.csv")
schedule = submission_in.schedule.tolist()
words = [line.translate(str.maketrans(symbols, "1234567")) for line in schedule]

In [8]:
# Display the translation table as a sanity check: it maps the code point of
# each schedule symbol to the code point of the digit that replaces it.
str.maketrans(symbols, "1234567")

{127877: 49,
 129334: 50,
 129420: 51,
 129501: 52,
 127876: 53,
 127873: 54,
 127872: 55}

In [9]:
# Length of each schedule string before optimisation.
[len(word) for word in words]

[2218, 2079, 1947]

In [10]:
def _permutation_nodes(word):
    """Return, in order, the id of every 7-symbol window of `word` that is a permutation."""
    windows = (word[start:start+7] for start in range(len(word)-6))
    return [perm2id[window] for window in windows if window in perm2id]


def _wildcard_table(nodes):
    """Fill the DP table of string lengths for one sequence of permutation ids.

    Row k holds, for each of 10 states, the length of the string covering
    nodes[0..k]. Every state starts at 7 (the first permutation written in
    full). State 0 only ever pays the exact cost `e`. States 1 and 9 may
    instead be entered from the state below by paying the one-mismatch cost
    `ew`. States 2..8 pay `e` and may be entered from the state below for free.
    NOTE(review): states 2..8 presumably keep the two wildcards apart; the
    original author marked this chain as needing a better transition.

    Raises:
        ValueError: if `nodes` is empty.
    """
    if not nodes:
        raise ValueError("word contains no 7-symbol permutation window")
    table = np.zeros((len(nodes), 10), np.int64)
    table[0, :] = 7
    for step in range(1, len(nodes)):
        e = costMat[nodes[step-1], nodes[step]]
        ew = costMatWild[nodes[step-1], nodes[step]]
        prev = table[step-1]  # read-only view of the previous row
        table[step, 0] = prev[0] + e
        table[step, 1] = min(prev[1] + e, prev[0] + ew)
        # TODO: better transition
        for state in range(2, 9):
            table[step, state] = min(prev[state], prev[state-1]) + e
        table[step, 9] = min(prev[9] + e, prev[8] + ew)
    return table


def _wildcard_index(prefix_length, left_perm, right_perm, ew):
    """Index, in the rebuilt string, of the symbol that must become a wildcard.

    `prefix_length` is the string length up to and including `left_perm`, so
    `left_perm` starts at `prefix_length - 7` and `right_perm` starts `ew`
    symbols later; the wildcard is the first position where the two overlap
    regions disagree.
    """
    left = np.array(list(map(int, left_perm[ew:])))
    right = np.array(list(map(int, right_perm[:-ew])))
    mis = np.where(left != right)[0][0]
    return prefix_length - 7 + ew + mis


def _backtrack(nodes, table):
    """Walk the DP table backwards and rebuild the shortened string.

    Returns:
        (text, wild): the rebuilt digit string (without wildcard marks) and
        the list of indices in `text` that must be replaced by a wildcard.
    """
    pieces = [perms[nodes[-1]]]  # the last permutation is always written in full
    track = np.argmin(table[-1])
    wild = []
    for step in range(len(nodes)-2, -1, -1):
        here, after = perms[nodes[step]], perms[nodes[step+1]]
        e = costMat[nodes[step], nodes[step+1]]
        ew = costMatWild[nodes[step], nodes[step+1]]
        take = e  # symbols of `here` that are not covered by `after`
        if track == 0:
            pass
        elif track == 1 or track == 9:
            # These states are entered from the state below by spending a
            # wildcard; ties are resolved in favour of the wildcard transition.
            if table[step, track] + e >= table[step, track-1] + ew:
                wild.append(_wildcard_index(table[step, track-1], here, after, ew))
                take = ew
                track = track - 1
        elif 2 <= track <= 8:
            if table[step, track] >= table[step, track-1]:
                track = track - 1
        else:
            raise AssertionError(f"unexpected DP state {track}")
        pieces.append(here[:take])
    assert track == 0
    return "".join(reversed(pieces)), wild


# Forward pass: one node list and one DP table per schedule string.
nodes_list = []
table_list = []
for word in words:
    nodes = _permutation_nodes(word)
    table = _wildcard_table(nodes)
    print(table[-1].min(), table[-1])
    nodes_list.append(nodes)
    table_list.append(table)

# backtrack
new_words = []
wilds = []
for nodes, table in zip(nodes_list, table_list):
    text, wild = _backtrack(nodes, table)
    wilds.append(wild)
    nsw = list(text)
    for w in wild:
        nsw[w] = "*"
    new_words.append("".join(nsw))

2217 [2218 2217 2217 2217 2217 2217 2217 2217 2217 2217]
2077 [2079 2077 2077 2077 2077 2077 2077 2077 2077 2077]
1947 [1947 1947 1947 1947 1947 1947 1947 1947 1947 1947]


In [11]:
# The score is the length of the longest schedule string, before -> after.
print("score: ", max(len(word) for word in words), "->", max(len(word) for word in new_words))

score:  2218 -> 2217


In [12]:
# Translate the digit alphabet back to the schedule symbols; "*" becomes the
# wildcard symbol (code point U+1F31F). The wildcard is written as an escape
# so both maketrans arguments are guaranteed to have 8 characters.
to_symbols = str.maketrans("1234567*", symbols + "\U0001F31F")
submission = pd.Series([a.translate(to_symbols) for a in new_words], name='schedule')
# NOTE(review): the lengths embedded in this file name (2478/2478/2474) do not
# match the scores printed by this notebook; confirm the intended name.
submission.to_csv('submission_tsp_6_perm_rebalance_2478_2478_2474_opt.csv', index=False)