In [None]:
from scipy.spatial.distance import cdist
from scipy.stats import mode
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from itertools import permutations, combinations

from numpy.random import default_rng
from numpy.linalg import norm
from scipy.special import softmax
from collections import Counter
from pathlib import Path
from sklearn.cluster import KMeans, kmeans_plusplus
import faiss
import torch

def normalize(x):
    """Standardize ``x`` along axis 0 (zero mean, unit standard deviation).

    Parameters
    ----------
    x : array-like
        Data whose statistics are taken along axis 0.

    Returns
    -------
    numpy.ndarray
        ``(x - mean) / std`` computed along axis 0, same shape as ``x``.

    Notes
    -----
    An entry whose standard deviation is exactly 0 is divided by 1 instead,
    so constant columns come out as 0 rather than NaN/inf.
    """
    x = np.asarray(x)
    std = x.std(axis=0)
    # Guard against 0/0 for constant columns.
    safe_std = np.where(std == 0, 1, std)
    return (x - x.mean(axis=0)) / safe_std

def get_features(path, labels=None, norm=True):
    """Load feature matrices and label vectors from an ``.npz`` archive.

    Every stored array is split into features (all columns but the last)
    and labels (the last column).

    Parameters
    ----------
    path : str or path-like
        Location of the ``.npz`` archive.
    labels : str, optional
        If truthy, the key of a single array to load; only that array is
        split and returned, and ``norm`` is ignored.
    norm : bool, default True
        When loading the ``'s'``/``'t'`` pair, pass both feature matrices
        through ``normalize`` before returning them.

    Returns
    -------
    tuple
        ``(x, y)`` when ``labels`` is given, otherwise
        ``(s_x, s_y, t_x, t_y)`` built from the ``'s'`` and ``'t'`` arrays.
    """
    # The archive object holds an open file handle; the context manager
    # releases it as soon as the arrays have been read into memory.
    with np.load(path) as data:
        if labels:
            s = data[labels]
            return s[:, :-1], s[:, -1]
        ss, st = data['s'], data['t']

    ssx, ssy = ss[:, :-1], ss[:, -1]
    stx, sty = st[:, :-1], st[:, -1]
    if norm:
        return normalize(ssx), ssy, normalize(stx), sty
    return ssx, ssy, stx, sty

def prototype_classifier(X, C, th=1):
    """Rank prototypes for each sample by a softmax over negative distance.

    Parameters
    ----------
    X : 2-D array, one sample per row.
    C : 2-D array, one prototype per row.
    th : int, default 1
        Number of best-ranked prototypes to report per sample.

    Returns
    -------
    top : ndarray
        The last ``th`` rows of the ascending argsort of the probabilities,
        reversed so the most probable prototype comes first; one column
        per sample.
    scores : ndarray
        Softmax probabilities, one row per prototype and one column per
        sample (each column sums to 1).
    """
    # Alternative score tried previously: softmax(1/(1+dist), axis=0).
    scores = softmax(-cdist(C, X), axis=0)
    ranking = np.argsort(scores, axis=0)
    top = ranking[-th:][::-1, :]
    return top, scores

# def prototype_classifier(X, C):
#     dist = cdist(C, X)
#     prob = softmax(-dist, axis=0)
#     return prob.T

def masked_prototypical_classifier(X, C, ratio=0.3, seed=2437, th=2):
    """Run ``prototype_classifier`` on a random subset of feature columns.

    One Bernoulli(``ratio``) draw is made per column of ``X`` from a
    generator seeded with ``seed``; columns whose draw is 1 are removed
    from both ``X`` and ``C`` before classification.

    Returns whatever ``prototype_classifier`` returns for the reduced
    inputs with the given ``th``.
    """
    rng = np.random.default_rng(seed)
    dropped = rng.binomial(1, ratio, X.shape[1]).astype(bool)
    kept = np.logical_not(dropped)
    return prototype_classifier(X[:, kept], C[:, kept], th=th)

def labeled_data_sampler(labels, shot=1, seed=1362):
    """Split sample indices into a few-shot labeled set and the remainder.

    Parameters
    ----------
    labels : 1-D array-like
        Class label of every sample.
    shot : int, default 1
        Number of distinct samples to draw per class. A class with fewer
        than ``shot`` members contributes all of its members.
    seed : int, default 1362
        Seed for the NumPy random generator.

    Returns
    -------
    idx : ndarray of int
        Selected indices, grouped by class in ascending label order.
    rest : ndarray of int
        Sorted indices of all samples not in ``idx``.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    picked = []
    # Iterate over the label values actually present, so labels need not
    # be exactly 0..K-1.
    for cls in np.unique(labels):
        members = np.flatnonzero(labels == cls)
        k = min(shot, members.size)
        # replace=False: a sample must not be drawn twice for one class.
        picked.append(rng.choice(members, k, replace=False))
    if picked:
        idx = np.concatenate(picked).astype(int)
    else:
        idx = np.empty(0, dtype=int)
    return idx, np.setdiff1d(np.arange(len(labels)), idx)

In [None]:
# For each of the four source indices, load the (s -> s+1 mod 4) archive and
# write the per-class mean of the source features to 's{s}_center.npy'.
for s in range(4):
    s_path = f'./OfficeHome/source_only/s{s}_t{(s+1)%4}.npz'
    ssx, ssy, stx, sty = get_features(s_path, norm=False)

    # One mean feature vector per class id 0..64.
    class_means = [ssx[ssy == i].mean(axis=0) for i in range(65)]
    ssc = np.stack(class_means)

    with open(f's{s}_center.npy', 'wb') as f:
        np.save(f, ssc)

In [None]:
# Load one source/target feature archive and split the target samples into
# a 3-shot labeled subset and the unlabeled remainder.
seed = 2020

path = './OfficeHome/source_only/s0_t1_2020.npz'
sx, sy, tx, ty = get_features(path, norm=False)

l_idx, u_idx = labeled_data_sampler(ty, shot=3, seed=seed)
ltx, lty = tx[l_idx], ty[l_idx]
utx, uty = tx[u_idx], ty[u_idx]


In [None]:
# Nearest-source-prototype accuracy on the target features for every ordered
# (source, target) pair, followed by the mean over the pairs evaluated.
avg = 0
n_pairs = 0  # counted in the loop so the average never relies on a literal
for s, t in permutations(range(4), 2):
    s_path = f'./OfficeHome/s2t_shot/s{s}_t{t}_{2024+s}.npz'
    ssx, ssy, stx, sty = get_features(s_path, norm=False)

    # Per-class mean features; stc is built from the true target labels and
    # is not used further in this cell.
    ssc = np.stack([ssx[ssy == i].mean(axis=0) for i in range(65)])
    stc = np.stack([stx[sty == i].mean(axis=0) for i in range(65)])

    pred, _ = prototype_classifier(stx, ssc)
    pred = pred.flatten()

    # Disabled refinement: re-classify with prototypes built from the
    # pseudo-labels once every class has been predicted at least once.
#     if np.unique(pred).__len__() == 65:
#         pseudo_c = np.stack([stx[pred == i].mean(axis=0) for i in range(65)])
#         pred, _ = prototype_classifier(stx, pseudo_c)

    print('='*10, f'source {s}, target {t}', '='*10)
    score = (pred == sty).mean()
    avg += score
    n_pairs += 1
    print(score)
print('Avg. score:', avg / n_pairs)

In [None]:
# Inputs for the clustering cells that follow: features/labels for one
# source/target pair, the source class prototypes, and k-means settings.
t_path = f'./OfficeHome/kmeans_source_only/s{0}_t{1}_2.npz'
s_path = f'./OfficeHome/s2t_shot/s{0}_t{1}.npz'

# Disabled variant: features from t_path, target labels from s_path.
# ssx, ssy, stx, _ = get_features(t_path, norm=False)
# _, _, _, sty = get_features(s_path, norm=False)
ssx, ssy, stx, sty = get_features(s_path, norm=False)

# Source class prototypes; computed before the float32 cast below.
ssc = np.stack([ssx[ssy == i].mean(axis=0) for i in range(65)])
# stc = np.stack([stx[sty == i].mean(axis=0) for i in range(65)])

# Disabled experiment: map target class means onto source prototypes.
# label_map, _ = prototype_classifier(stc, ssc)
# label_map = label_map.flatten()
# pred = label_map[sty.astype(int)]
# print(pred)
# print((pred == correct_sty).mean())

# Cast both feature matrices to float32 ahead of the faiss calls.
stx, ssx = stx.astype('float32'), ssx.astype('float32')

n_clusters, seed = 65, 1347
# print(n_clusters)
# pseudo_c = np.stack([stx[pred.flatten() == i].mean(axis=0) for i in range(65)])

In [None]:
# k-means++ seeding on the target features; only the chosen centres are
# kept, the second return value is discarded.
centers, _ = kmeans_plusplus(
    stx,
    n_clusters=n_clusters,
    random_state=seed,
)

In [None]:
# Fit faiss k-means on the target features, starting from the source class
# prototypes, then record the nearest-centroid id of every target sample.
kmeans = faiss.Kmeans(
    stx.shape[1],
    n_clusters,
    niter=300,
    nredo=5,
    gpu=True,
    seed=seed,
)
kmeans.train(stx, init_centroids=ssc.astype('float32'))
t_pred = kmeans.index.search(stx, 1)[1].flatten()

In [None]:
# Keep the centroids of the fitted k-means object and print their shape
# as a quick sanity check.
centroids = kmeans.centroids
print(centroids.shape)

In [None]:
# Majority true label among the target samples assigned to each cluster.
centroids_y = np.array([
    int(mode(sty[t_pred == i]).mode.item())
    for i in range(n_clusters)
])
print(centroids_y)

In [None]:
# Assign every k-means centroid to its most probable source class prototype
# (th defaults to 1), then flatten to one class index per centroid.
pred, _ = prototype_classifier(centroids, ssc)
pred = pred.flatten()

In [None]:
# Show the source-prototype index chosen for each centroid.
print(pred)

In [None]:
# Fraction of centroids whose prototype-assigned class matches the majority
# true label of their cluster (centroids_y).
print((pred == centroids_y).mean())

In [None]:
# Two-stage prediction: each target sample takes the class that was assigned
# to its most probable k-means centroid; print the resulting accuracy.
centroids_map = pred.copy()
c_pred, _ = prototype_classifier(stx, centroids)
c_pred = c_pred.flatten()
new_pred = np.take(centroids_map, c_pred)

print(np.mean(new_pred == sty))

In [None]:
# For every source class prototype, look up its single nearest entry in the
# k-means index and keep the returned ids as a flat array.
s_pred = kmeans.index.search(ssc.astype('float32'), 1)[1].flatten()

In [None]:
# Display the prototype-to-centroid ids computed in the previous cell.
s_pred

In [None]:
# Cluster purity: for every cluster, count the samples carrying its most
# frequent true label, then divide the total by the number of target samples.
cnt = 0

for i in range(n_clusters):
    idx = np.where(t_pred==i)[0]
    n = sty[idx]
    if n.size == 0:
        # An empty cluster contributes nothing (and has no mode to index).
        continue
    # Size of the majority label, taken from the label counts directly so
    # the result does not depend on whether scipy's mode() returns arrays
    # or scalars for 1-D input.
    cnt += np.unique(n, return_counts=True)[1].max()
print(cnt/len(stx))