In [1]:
%load_ext autoreload
%autoreload 2
import os, pickle
import numpy as np
import torch
from algorithms.selection import skey, get_topk, get_candidates_with_label, tripet_greedy, nn_greedy

In [2]:
triplets = np.array(pickle.load(open("data/bm_triplets/3c2_unique=182/train_triplets.pkl", "rb")))
embeds = pickle.load(open("embeds/bm/human/MTL.BCETN_train_emb10.pkl", "rb"))
labels = pickle.load(open("/net/scratch/hanliu-shared/data/bm/embs/dwac_train_emb10.merged10.pkl", "rb"))[2]

In [3]:
tripet_greedy(embeds, 10, triplets, labels=labels, topk=10, verbose=False)

((3, 11, 17, 25, 26, 34, 56, 62, 95, 103), 41)

In [4]:
lpips_d_matrix = pickle.load(open("embeds/lpips/lpips.bm.train.pkl", "rb"))
lpips_triplets = pickle.load(open("data/bm_lpips_triplets/train_triplets.pkl", "rb"))
len(lpips_d_matrix), len(lpips_triplets)

(160, 669920)

In [5]:
tripet_greedy(embeds, 10, lpips_d_matrix, labels=labels, topk=10, verbose=False)

((21, 23, 33, 92, 99, 104, 106, 113, 131, 153), 11872)

In [6]:
tripet_greedy(embeds, 10, lpips_triplets, labels=labels, topk=10, verbose=False)

((70, 73, 75, 78, 79, 121, 126, 130, 146, 153), 7691)

Debugging

In [7]:
z = torch.tensor(embeds)
dist = torch.cdist(z, z).numpy()
uni = np.unique(triplets)
n = len(uni)

In [8]:
from collections import defaultdict
from copy import deepcopy
scores = defaultdict(lambda: 0)
visits = defaultdict(lambda: 0)

In [9]:
for t in triplets:
    a, p, n = t
    key = skey([p, n])
    if dist[a, p] <= dist[a, n]:      
        scores[key] += 1
    visits[key] += 1

In [10]:
m = 10
topk = 10
curr_scores = deepcopy(scores)
curr_visits = deepcopy(visits)
if labels is not None:
    cand_scores, cand_visits = get_candidates_with_label(scores, visits, labels)
    beam, w = get_topk(cand_scores, cand_visits, topk, metric="count", verbose=True)
else:
    beam, w = get_topk(curr_scores, curr_visits, topk, metric="count", verbose=True)
for i in range(3, m+1):
    new_scores = defaultdict(lambda: 0)
    new_visits = deepcopy(new_scores)
    for b in beam:
        for k in uni:
            if k not in b:
                key = skey(b + (k,))
                base_score = new_scores[key]
                base_visit = new_visits[key]
                if base_score == 0 and curr_scores[b] > 0:
                    new_scores[key] += curr_scores[b]
                if base_visit == 0 and curr_visits[b] > 0:
                    new_visits[key] += curr_visits[b]
                for j in b:
                    new_scores[key] += scores[skey([j, k])]
                    new_visits[key] += visits[skey([j, k])]
    curr_scores = new_scores
    curr_visits = new_visits
    beam, w = get_topk(curr_scores, curr_visits, topk, metric="count", verbose=True)

(56, 95) 2.0000 10
(56, 95, 103) 8.0000 10
(3, 56, 95, 103) 17.0000 10
(3, 34, 56, 95, 103) 21.0000 10
(3, 11, 34, 56, 95, 103) 24.0000 10
(3, 17, 25, 34, 56, 95, 103) 30.0000 10
(3, 11, 17, 25, 34, 56, 95, 103) 35.0000 10
(3, 11, 13, 17, 25, 34, 56, 95, 103) 37.0000 10
(3, 11, 17, 25, 26, 34, 56, 62, 95, 103) 41.0000 10


1NN Acc

In [11]:
y_train = pickle.load(open("/net/scratch/hanliu-shared/data/bm/embs/dwac_train_emb10.merged10.pkl", "rb"))[2]
y_valid = pickle.load(open("/net/scratch/hanliu-shared/data/bm/embs/dwac_valid_emb10.merged10.pkl", "rb"))[2]
train = pickle.load(open("embeds/bm/human/MTL.BCETN_train_emb10.pkl", "rb"))
valid = pickle.load(open("embeds/bm/human/MTL.BCETN_valid_emb10.pkl", "rb"))

In [12]:
nn_greedy((train, valid), 10, (y_train, y_valid), topk=10, metric="acc")

((1, 12, 19, 21, 23, 26, 30, 33, 50, 143), 1.0)

In [13]:
zt = torch.tensor(train)
zv = torch.tensor(valid)
dist = torch.cdist(zt, zv).numpy()
y1, y2 = y_train, y_valid
uni = np.arange(len(dist))
n = len(uni)
dist.shape

(160, 40)

In [14]:
from collections import defaultdict
from copy import deepcopy
scores = defaultdict(lambda: 0)
nearns = defaultdict(lambda: np.zeros(n, dtype=np.int8) - 1)

In [15]:
from itertools import combinations
for c in combinations(range(n), 2):
    key = skey(c)
    dsx = dist[key, :]
    nn1mask = np.argmin(dsx, axis=0)
    nn1 = np.take(key, nn1mask)
    nn1pred = np.take(y1, nn1)
    score = (nn1pred == y2).sum()
    scores[key] = score
    nearns[key] = nn1

In [16]:
m = 10
topk = 100
curr_scores = deepcopy(scores)
# curr_nearns = deepcopy(nearns)
beam, w = get_topk(curr_scores, defaultdict(lambda: len(y2)), topk, metric="acc", verbose=True)
for _ in range(3, m+1):
    new_scores = defaultdict(lambda: 0)
    # new_nearns = deepcopy(new_scores)
    for b in beam:
        for k in uni:
            if k not in b:
                key = skey(b + (k,))
                dsx = dist[key, :]
                nn1mask = np.argmin(dsx, axis=0)
                nn1 = np.take(key, nn1mask)
                nn1pred = np.take(y1, nn1)
                score = (nn1pred == y2).sum()
                new_scores[key] = score
                # new_nearns[key] = nn1
    curr_scores = new_scores
    # curr_nearns = new_nearns
    beam, w = get_topk(curr_scores, defaultdict(lambda: len(y2)), topk, metric="acc", verbose=True)

(50, 143) 1.0000 100
(1, 50, 143) 1.0000 100
(1, 12, 50, 143) 1.0000 100
(1, 12, 19, 50, 143) 1.0000 100
(1, 12, 19, 21, 50, 143) 1.0000 100
(1, 12, 19, 21, 23, 50, 143) 1.0000 100
(1, 12, 19, 21, 23, 26, 50, 143) 1.0000 100
(1, 12, 19, 21, 23, 26, 30, 50, 143) 1.0000 100
(1, 12, 19, 21, 23, 26, 30, 33, 50, 143) 1.0000 100


In [17]:
np.take(y_train, beam[0])

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1])