In [30]:
from scipy.spatial.distance import cdist
from scipy.stats import mode
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from itertools import permutations, combinations

from numpy.random import default_rng
from numpy.linalg import norm
from scipy.special import softmax
from collections import Counter
from pathlib import Path

def normalize(x):
    """Standardize ``x`` along axis 0 (zero mean, unit standard deviation).

    Parameters
    ----------
    x : array-like
        Numeric data. The mean and standard deviation are taken along axis 0.

    Returns
    -------
    numpy.ndarray
        ``(x - mean) / std`` evaluated per column. A column whose standard
        deviation is exactly zero is only centred (so it comes out as zeros)
        instead of being divided by zero.
    """
    x = np.asarray(x)
    centered = x - x.mean(axis=0)
    std = x.std(axis=0)
    # A constant column has std == 0 and would turn into NaN (0 / 0) after the
    # division; substitute 1 so such columns stay finite. ones_like keeps the
    # dtype of std, so the result dtype is unchanged.
    safe_std = np.where(std == 0, np.ones_like(std), std)
    return centered / safe_std

def get_features(path, norm=True):
    """Load the two feature/label splits stored in an ``.npz`` archive.

    The archive must contain 2-D arrays under the keys ``'s'`` and ``'t'``
    (presumably the source- and target-domain samples -- confirm against the
    script that writes these files). In each array the last column is split
    off as the label vector and the remaining columns are the features.

    Parameters
    ----------
    path : str or path-like
        Location of the ``.npz`` file.
    norm : bool, default True
        If true, each feature matrix is passed through ``normalize``
        independently. Inside this function the parameter name shadows the
        ``norm`` imported from ``numpy.linalg`` at the top of the file.

    Returns
    -------
    tuple
        ``(ssx, ssy, stx, sty)``: features and labels of ``'s'``, then
        features and labels of ``'t'``.
    """
    # np.load on an .npz returns an NpzFile that holds the archive open;
    # reading the members inside a with-block guarantees the handle is closed.
    with np.load(path) as data:
        ss, st = data['s'], data['t']
    ssx, ssy = ss[:, :-1], ss[:, -1]
    stx, sty = st[:, :-1], st[:, -1]
    if norm:
        return normalize(ssx), ssy, normalize(stx), sty
    return ssx, ssy, stx, sty

def prototype_classifier(X, C, th=1):
    """Rank prototypes for every sample by softmax over negative distance.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
        Samples to classify.
    C : array of shape (n_prototypes, n_features)
        One prototype per row.
    th : int, default 1
        How many of the most probable prototypes to return per sample.

    Returns
    -------
    tuple
        ``(top, prob)``. ``prob`` has shape (n_prototypes, n_samples) and each
        column sums to one. ``top`` holds the last ``th`` rows of the
        ascending argsort of ``prob``, i.e. shape (th, n_samples) with the
        most probable prototype index in the LAST row.
    """
    # Negated pairwise distances serve as logits: nearer prototype -> larger logit.
    logits = -cdist(C, X)
    prob = softmax(logits, axis=0)
    ranking = prob.argsort(axis=0)
    return ranking[-th:], prob

def masked_prototypical_classifier(X, C, ratio=0.3, seed=2437):
    """Run ``prototype_classifier`` on a random subset of the feature columns.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
        Samples to classify.
    C : array of shape (n_prototypes, n_features)
        Prototypes; the same columns are removed from ``X`` and ``C``.
    ratio : float, default 0.3
        Probability with which each feature column is dropped.
    seed : int, default 2437
        Seed of the generator that draws the mask, so a given seed always
        selects the same columns.

    Returns
    -------
    tuple
        Whatever ``prototype_classifier`` returns for the reduced inputs
        (called with its default ``th``).
    """
    generator = np.random.default_rng(seed)
    # One Bernoulli(ratio) draw per feature column; a draw of 1 marks the
    # column as dropped, so the columns that survive are the zeros.
    dropped = generator.binomial(1, ratio, X.shape[1]).astype(bool)
    kept = np.logical_not(dropped)
    return prototype_classifier(X[:, kept], C[:, kept])

In [3]:
# Directory whose entries are enumerated below (one entry per class, judging
# by how class_name is used later -- TODO confirm).
data_path = Path('/work/chu980802/data/OfficeHome/Art')
# Position in the name-sorted directory listing -> entry name.
class_name = dict(enumerate(entry.name for entry in sorted(data_path.iterdir())))

#### Protonet Pseudo Label Training Space

In [None]:
# Unfinished cell: it iterates over all 12 ordered (s, t) pairs of the four
# domain indices and loads each feature file, but never evaluates anything.
avg = 0

for s, t in permutations(range(4), 2):
    s_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
    ssx, ssy, stx, sty = get_features(s_path, norm=False)
    # NOTE(review): no classifier is run and `avg` is never updated here.

# `avg` is still 0 at this point, so this always prints 0.0.
print(f'Avg acc.:', avg/12)

#### Protonet Domain Center

In [14]:
# Domain-separability check: build one prototype per split (the mean feature
# vector of the 's' rows and of the 't' rows) and measure how well the
# nearest-prototype rule tells the two splits apart.
avg = 0
for s, t in permutations(range(4), 2):
    s_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
    ssx, ssy, stx, sty = get_features(s_path, norm=False)

    # The two prototypes, stacked as rows 0 and 1.
    sc, tc = ssx.mean(axis=0), stx.mean(axis=0)
    c = np.stack([sc, tc])

    # Pool both splits; the target label is 0 for 's' rows and 1 for 't' rows,
    # matching the row order of `c`.
    x = np.concatenate([ssx, stx])
    y = np.concatenate([np.zeros(len(ssy)), np.ones(len(sty))])

    # pred has shape (1, n_samples); the comparison broadcasts against y.
    pred, _ = prototype_classifier(x, c)
    score = (pred == y).mean()
    # Euclidean distance between the two prototypes.
    dist = np.linalg.norm(sc - tc)

    print('-'*10, f'source {s}, target {t}', '-'*10)
    print(score)
    print(dist)
    avg += score
# 12 = number of ordered pairs produced by permutations(range(4), 2).
print(f'Avg acc.:', avg/12)

---------- source 0, target 1 ----------
0.8002101471029721
4.166670686551401
---------- source 0, target 2 ----------
0.754750593824228
3.277573206803851
---------- source 0, target 3 ----------
0.6526901112113015
1.8081951966668202
---------- source 1, target 0 ----------
0.7718402882017412
3.6082645349661986
---------- source 1, target 2 ----------
0.7120129121512566
2.213101755121368
---------- source 1, target 3 ----------
0.7162476722532588
2.6389116816683
---------- source 2, target 0 ----------
0.7826603325415677
3.570035466838109
---------- source 2, target 1 ----------
0.7870647913304127
3.5365292267186796
---------- source 2, target 3 ----------
0.6655896607431341
1.9173452956243544
---------- source 3, target 0 ----------
0.689209498046288
2.2194521074202824
---------- source 3, target 1 ----------
0.8157588454376163
4.675483083514556
---------- source 3, target 2 ----------
0.6772444034156474
1.8055738841796143
Avg acc.: 0.7354399380216186


In [29]:
# Nearest-class-prototype baseline: prototypes are the per-class means of the
# 's' features, and accuracy is measured on the 't' features of the same file.
avg = 0
for s, t in permutations(range(4), 2):
    s_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
    ssx, ssy, stx, sty = get_features(s_path, norm=False)

    # One prototype per class label 0..64: mean of the rows carrying that label.
    ssc = np.stack([ssx[ssy == i].mean(axis=0) for i in range(65)])
    # pred has shape (1, n_samples); the comparison broadcasts against sty.
    pred, _ = prototype_classifier(stx, ssc)
    score = (pred == sty).mean()
    print('-'*10, f'source {s}, target {t}', '-'*10)
    print(score)
    avg += score
# No C hyper-parameter is involved in this cell. The previous summary line
# interpolated `c`, a name this cell never assigns, so it printed whatever an
# earlier-run cell had left in the kernel. Report the average on its own.
# 12 = number of ordered pairs produced by permutations(range(4), 2).
print('Avg acc.:', avg/12)

---------- source 0, target 1 ----------
0.5130232558139535
---------- source 0, target 2 ----------
0.6801554641060814
---------- source 0, target 3 ----------
0.7562907735321528
---------- source 1, target 0 ----------
0.5381033022861982
---------- source 1, target 2 ----------
0.6332876085962506
---------- source 1, target 3 ----------
0.6498136067101584
---------- source 2, target 0 ----------
0.5537679932260796
---------- source 2, target 1 ----------
0.46069767441860465
---------- source 2, target 3 ----------
0.7467381174277726
---------- source 3, target 0 ----------
0.6583403895004234
---------- source 3, target 1 ----------
0.5076744186046511
---------- source 3, target 2 ----------
0.7960676726108825
C = 10, Avg acc.: 0.6244966897361007


#### Parameter selection on source

In [17]:
# Sweep the LinearSVC regularisation strength C. For every value the
# classifier is fit on the 's' split of each of the 12 feature files and its
# accuracy on the 't' split of the same file is averaged.
# NOTE(review): the heading says "on source", but the score below is computed
# on (stx, sty), not on the split used for fitting -- confirm this is intended.
c_list = [1e-4, 1e-3, 1e-2, 1e-1, 1, 10]
total_avg = []
for c in c_list:
    avg = 0
    for s, t in permutations(range(4), 2):
        s_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
        ssx, ssy, stx, sty = get_features(s_path, norm=False)

        # The per-class prototypes and the arange(65) label vector that used
        # to be built here were never read (the fit uses the raw samples), so
        # they have been dropped.
        c1 = LinearSVC(random_state=2476, C=c, max_iter=5000)
        c1.fit(ssx, ssy)
        pred = c1.predict(stx)
        score = (pred == sty).mean()
        avg += score
    # 12 = number of ordered pairs produced by permutations(range(4), 2).
    print(f'C = {c}, Avg acc.:', avg/12)
    total_avg.append(avg/12)
# Mean and spread of the per-C averages across the whole sweep.
print('Total Avg acc.:', np.mean(total_avg))
print('std:', np.std(total_avg))

C = 0.0001, Avg acc.: 0.6034702594687053
C = 0.001, Avg acc.: 0.6164105715510758
C = 0.01, Avg acc.: 0.6168803658897942
C = 0.1, Avg acc.: 0.5992471467409616




C = 1, Avg acc.: 0.5822636667859786




C = 10, Avg acc.: 0.5617626434076917
Total Avg acc.: 0.5966724423073679
std: 0.019481188421463846




#### Protonet classifier for source on target

In [8]:
# Exploratory cell: for the FIRST (s, t) pair only, run the feature-masked
# prototype classifier with 1000 different mask seeds and print each
# prediction array.
# NOTE(review): nothing is aggregated or scored here; presumably a vote over
# the masked predictions was planned -- confirm before relying on this cell.
avg = 0
for s, t in permutations(range(4), 2):
    s_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
    ssx, ssy, stx, sty = get_features(s_path, norm=False)

    # One prototype per class label 0..64: mean of the 's' rows with that label.
    ssc = np.stack([ssx[ssy == i].mean(axis=0) for i in range(65)])
    for i in range(1000):
        # The seed changes every iteration, so each run drops a different
        # random subset of feature columns (each column with probability 0.3).
        pred1, _ = masked_prototypical_classifier(stx, ssc, ratio=0.3, seed=i)
        
        # Emits 1000 (abbreviated) arrays into the cell output.
        print(pred1)
    # Stop after the first pair; the remaining 11 pairs are never processed.
    break
# Disabled scoring code. It refers to `pred` and `tsy`, neither of which is
# assigned in this cell.
#     score = (pred == tsy).mean()
#     print('-'*10, f'source {s}, target {t}', '-'*10)
#     print(score)
#     avg += score
# `avg` is never updated above, so this always prints 0.0.
print(f'Avg acc.:', avg/12)

[[18 15  0 ... 56 64 56]]
[[23 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[50 22  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[11 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18  5  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 55  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 62  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[50 62  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 62  0 .

[[18 37  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18  2  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[50 19  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 62  0 ... 56 64 56]]
[[11 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 15  0 .

[[50 37  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[50  5  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[18 62  0 ... 56 64 56]]
[[50 17  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[18 55  0 ... 56 64 56]]
[[50 17  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[23 37  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[50  9  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[50 17  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 15  0 .

[[18 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[44 37  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[50 37  0 ... 56 64 56]]
[[50 17  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[50 17  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 17  0 ... 56 64 56]]
[[50 15  0 ... 56 64 56]]
[[18 37  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 19  0 ... 56 64 56]]
[[18 15  0 ... 56 64 56]]
[[18 19  0 .

#### Hyperparameter selection

In [18]:
# Sweep LinearSVC's C when the classifier is fit on class prototypes only:
# each of the 65 prototypes (per-class mean of the second split in the file)
# is a single training point, and accuracy is measured on the first split.
c_list = [1e-4, 1e-3, 1e-2, 1e-1, 1, 10]
total_avg = []
ratio = 0.2
for c in c_list:
    pair_scores = []
    for s, t in permutations(range(4), 2):
        # Note the swapped indices in the file name: s{t}_t{s}.
        t_path = f'./OfficeHome/partial/s{t}_t{s}_{ratio}.npz'
        ttx, tty, tsx, tsy = get_features(t_path, norm=False)

        # Prototype k is the mean of the rows labelled k; its label is k.
        labels = np.arange(65)
        tsc = np.stack([tsx[tsy == k].mean(axis=0) for k in labels])

        c1 = LinearSVC(random_state=12453, C=c, max_iter=5000).fit(tsc, labels)
        score = c1.score(ttx, tty)
        pair_scores.append(score)
    # Sequential sum over the 12 ordered pairs, then the mean for this C.
    avg = sum(pair_scores)
    print(f'C = {c}, Avg acc.:', avg/12)
    total_avg.append(avg/12)
# Mean and spread of the per-C averages across the whole sweep.
print('Total Avg acc.:', np.mean(total_avg))
print('std:', np.std(total_avg))

C = 0.0001, Avg acc.: 0.6989127597503392
C = 0.001, Avg acc.: 0.7158017242959386
C = 0.01, Avg acc.: 0.742664405817068
C = 0.1, Avg acc.: 0.732535106388521
C = 1, Avg acc.: 0.6997192937689464
C = 10, Avg acc.: 0.6904202812537279
Total Avg acc.: 0.7133422618790902
std: 0.018944723753003462


#### Protonet for partial space

In [14]:
# Nearest-class-prototype accuracy on the "partial" feature files: prototypes
# come from the second split stored in each file, and accuracy is measured on
# the first split.
avg = 0
ratio = 0.2
for s, t in permutations(range(4), 2):
    # Note the swapped indices in the file name: s{t}_t{s}.
    t_path = f'./OfficeHome/partial/s{t}_t{s}_{ratio}.npz'
    ttx, tty, tsx, tsy = get_features(t_path, norm=False)

    # One prototype per class label 0..64: mean of the rows with that label.
    class_means = []
    for label in range(65):
        class_means.append(tsx[tsy == label].mean(axis=0))
    tsc = np.stack(class_means)

    # pred has shape (1, n_samples); the comparison broadcasts against tty.
    pred, _ = prototype_classifier(ttx, tsc)
    score = np.mean(pred == tty)

    header = f'source {s}, target {t}'
    print('-'*10, header, '-'*10)
    print(score)
    avg += score
# 12 = number of ordered pairs produced by permutations(range(4), 2).
print('Avg acc.:', avg/12)

---------- source 0, target 1 ----------
0.6569767441860465
---------- source 0, target 2 ----------
0.782350251486054
---------- source 0, target 3 ----------
0.8215284249767009
---------- source 1, target 0 ----------
0.6845893310753599
---------- source 1, target 2 ----------
0.7508001828989483
---------- source 1, target 3 ----------
0.7632805219012115
---------- source 2, target 0 ----------
0.6613039796782387
---------- source 2, target 1 ----------
0.6690697674418604
---------- source 2, target 3 ----------
0.820363466915191
---------- source 3, target 0 ----------
0.72692633361558
---------- source 3, target 1 ----------
0.6937209302325581
---------- source 3, target 2 ----------
0.8454503886602652
Avg acc.: 0.7396966935890013


#### top-2 accuracy

In [74]:
# Top-2 analysis of the class-prototype classifier: for each class, list the
# classes that outranked the true class when the true class came second, then
# report the fraction of samples whose true class is among the two most
# probable prototypes.
avg = 0
n_pairs = 0
for s, t in permutations(range(4), 2):
    s_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
    ssx, ssy, stx, sty = get_features(s_path, norm=False)
    # One prototype per class label 0..64: mean of the 's' rows with that label.
    ssc = np.stack([ssx[ssy == i].mean(axis=0) for i in range(65)])

    # pred has shape (2, n_samples), ordered by ascending probability:
    # row 0 is the runner-up prototype, row 1 the most probable one.
    pred, _ = prototype_classifier(stx, ssc, th=2)
    runner_up, best = pred
    for i in range(65):
        # Samples of class i whose true class was ranked second; count which
        # class was ranked first instead.
        beaten_by = Counter(best[(runner_up == i) & (sty == i)])
        print([(class_name[a], b) for a, b in beaten_by.most_common(3)])
    # A sample is a top-2 hit when either of its two candidates equals its label.
    score = (pred == sty).any(axis=0).mean()
    print('-'*10, f'source {s}, target {t}', '-'*10)
    print(score)
    avg += score
    n_pairs += 1
    # Only the first (s, t) pair is analysed, which limits the per-class output.
    break
# Average over the pairs actually evaluated. Dividing by a fixed 12 while the
# loop stops after one pair would report one twelfth of the real score.
print('Avg acc.:', avg / n_pairs)

[('Exit_Sign', 4), ('Chair', 1), ('Folder', 1)]
[('Chair', 3), ('Alarm_Clock', 1), ('Flipflops', 1)]
[('Trash_Can', 5), ('Exit_Sign', 1), ('Folder', 1)]
[('Couch', 4), ('Folder', 2), ('Postit_Notes', 2)]
[('Chair', 2), ('Paper_Clip', 2)]
[('Soda', 7), ('Clipboards', 3), ('Bucket', 2)]
[('Trash_Can', 16), ('Scissors', 1), ('Mop', 1)]
[('Keyboard', 1), ('Folder', 1)]
[('Calculator', 7), ('Bucket', 1)]
[('Sneakers', 2), ('Bucket', 2), ('Push_Pin', 1)]
[('Table', 8), ('Couch', 2)]
[('Folder', 2), ('Exit_Sign', 1), ('Notebook', 1)]
[('Monitor', 7), ('Laptop', 5), ('Calculator', 1)]
[('Bed', 3), ('Folder', 2), ('Chair', 1)]
[('Mop', 1), ('Trash_Can', 1)]
[('Lamp_Shade', 18), ('Pan', 1)]
[('Hammer', 3), ('Screwdriver', 3), ('Exit_Sign', 1)]
[('Folder', 1), ('Batteries', 1), ('Pencil', 1)]
[]
[('Desk_Lamp', 2), ('Pan', 1)]
[('Shelf', 1), ('Calendar', 1), ('Folder', 1)]
[('Trash_Can', 1), ('Sneakers', 1), ('Soda', 1)]
[('Fan', 7), ('Glasses', 2), ('Fork', 2)]
[('Exit_Sign', 7), ('Clipboards', 4