In [73]:
from scipy.spatial.distance import cdist
from scipy.stats import mode
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from itertools import permutations, combinations

from numpy.random import default_rng
from numpy.linalg import norm
from scipy.special import softmax
from sklearn.decomposition import PCA

def normalize(x):
    """Standardize ``x`` along axis 0 to zero mean and unit standard deviation.

    Args:
        x: NumPy array; statistics are computed over axis 0 (per column for
            a 2-D input).

    Returns:
        Array of the same shape as ``x`` holding ``(x - mean) / std``.
        Columns whose standard deviation is exactly zero (constant features)
        are only centered, so they come out as zeros instead of NaN/inf.
    """
    mu = x.mean(axis=0)
    sigma = x.std(axis=0)
    # Adding the boolean mask turns every exact 0 into 1 and leaves all other
    # entries unchanged, without altering the dtype of ``sigma``.
    sigma = sigma + (sigma == 0)
    return (x - mu) / sigma

def get_features(path, norm=True):
    """Load feature/label arrays from an ``.npz`` archive.

    The archive must contain the arrays ``'s'`` and ``'t'``. In each of them
    the last column is split off as the label and the remaining columns are
    returned as the feature matrix.

    Args:
        path: Path of the ``.npz`` file to read.
        norm: If true, both feature matrices are passed through
            ``normalize`` (each one independently). Note that inside this
            function the name shadows ``numpy.linalg.norm``.

    Returns:
        Tuple ``(ssx, ssy, stx, sty)``: features and labels taken from
        ``'s'``, followed by features and labels taken from ``'t'``.
        Only ``sty`` is cast to ``int``; ``ssy`` keeps the stored dtype.
    """
    # np.load on an .npz returns a lazy NpzFile holding an open file handle;
    # the context manager closes it once both arrays have been read.
    with np.load(path) as data:
        ss, st = data['s'], data['t']
    ssx, ssy = ss[:, :-1], ss[:, -1]
    stx, sty = st[:, :-1], st[:, -1]
    if norm:
        return normalize(ssx), ssy, normalize(stx), sty.astype(int)
    return ssx, ssy, stx, sty.astype(int)

def prototype_classifier(X, C, th=2):
    """Rank prototypes for every sample by Euclidean distance.

    Args:
        X: 2-D array of samples, one row per sample.
        C: 2-D array of prototypes, one row per prototype, with the same
            number of columns as ``X``.
        th: Number of best candidates to keep per sample.

    Returns:
        Tuple ``(idx, d)``:
          * ``idx`` -- array of shape ``(th, n_samples)`` with the indices of
            the ``th`` most probable prototypes, ordered by increasing
            probability, so ``idx[-1]`` is the top-1 prediction.
          * ``d`` -- array of shape ``(th, n_samples)`` with the ``th``
            smallest distances, ordered by increasing distance, so ``d[0]``
            is the distance to the nearest prototype. Note that the row
            order is therefore the reverse of ``idx``.
    """
    dist = cdist(C, X)  # shape (n_prototypes, n_samples)
    # Softmax over the prototype axis: closer prototype -> larger probability.
    prob = softmax(-dist, axis=0)
    # ndarray.sort() sorts in place and returns None, so indexing its result
    # raised a TypeError; np.sort returns a sorted copy instead.
    nearest = np.sort(dist, axis=0)[:th, :]
    return prob.argsort(axis=0)[-th:, :], nearest

#### Guess the transformation

In [69]:
# Iterate over every ordered pair (s, t) with s != t drawn from range(4)
# (12 pairs) and, for each pair, load the stored matrix and the features.
# NOTE(review): this cell looks unfinished -- no prediction is made and `avg`
# is never incremented, so the final print always reports 0.0.
avg = 0
for s, t in permutations(range(4), 2):
    org_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
    new_path = f'./OfficeHome/s2t/s{s}_t{t}.npz'
    # `m` is loaded but not used anywhere else in this cell.
    m = np.load(new_path)['m']
    
    # Raw (un-normalized) features and labels for this pair.
    osx, osy, otx, oty = get_features(org_path, norm=False)
    # Per-label mean of the 's' features for labels 0..64 (presumably the 65
    # Office-Home classes -- TODO confirm); computed but not used in this cell.
    osc = np.stack([osx[osy == i].mean(axis=0) for i in range(65)])
# 12 == number of (s, t) pairs produced by permutations(range(4), 2).
print('Avg acc.:', avg/12)

Avg acc.: 0.0


#### After linear transformation

In [72]:
# Nearest-prototype accuracy after applying the stored matrix `m` to the
# 't' features, evaluated for every ordered pair (s, t) with s != t.
avg = 0
for s, t in permutations(range(4), 2):
    org_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
    new_path = f'./OfficeHome/s2t/s{s}_t{t}.npz'

    osx, osy, otx, oty = get_features(org_path, norm=False)
    # Close the NpzFile handle as soon as the matrix has been read.
    with np.load(new_path) as mapping:
        m = mapping['m']
    ntx = otx @ m.T
    nsx = osx @ m.T  # NOTE(review): computed but not used below.
    # Prototypes are the per-label means of the *untransformed* 's' features
    # for labels 0..64.
    osc = np.stack([osx[osy == i].mean(axis=0) for i in range(65)])

    pred, _ = prototype_classifier(ntx, osc)
    # `pred` holds the top-`th` candidates per sample, ordered by increasing
    # probability. Comparing the whole (th, N) array with `oty` would average
    # over every candidate row and under-report accuracy, so only the last
    # row (the most probable prototype) is scored. atleast_2d keeps this
    # working if the classifier returns a flat vector of predictions.
    top1 = np.atleast_2d(pred)[-1]
    score = (top1 == oty).mean()
    print('-'*10, f'source {s}, target {t}', '-'*10)
    print(score)
    avg += score
# 12 == number of (s, t) pairs produced by permutations(range(4), 2).
print('Avg acc.:', avg/12)

---------- source 0, target 1 ----------
0.61
---------- source 0, target 2 ----------
0.7770919067215364
---------- source 0, target 3 ----------
0.7814538676607642
---------- source 1, target 0 ----------
0.6278577476714648
---------- source 1, target 2 ----------
0.7233653406492913
---------- source 1, target 3 ----------
0.7034016775396086
---------- source 2, target 0 ----------
0.6879762912785775
---------- source 2, target 1 ----------
0.5927906976744186
---------- source 2, target 3 ----------
0.7996272134203168
---------- source 3, target 0 ----------
0.7345469940728196
---------- source 3, target 1 ----------
0.6158139534883721
---------- source 3, target 2 ----------
0.8374485596707819
Avg acc.: 0.7076145208206626
