In [None]:
from scipy.spatial.distance import cdist
from scipy.stats import mode
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from itertools import permutations, combinations

from numpy.random import default_rng
from numpy.linalg import norm
from scipy.special import softmax
from collections import Counter
from pathlib import Path

def normalize(x):
    """Standardize ``x`` along axis 0 (zero mean, unit standard deviation).

    Args:
        x: numpy array; statistics are computed over axis 0.

    Returns:
        ``(x - mean) / std`` computed along axis 0. Entries whose standard
        deviation is exactly 0 (constant along axis 0) come out as 0 instead
        of NaN.
    """
    centered = x - x.mean(axis=0)
    scale = x.std(axis=0)
    # A constant column has std == 0 and a centered value of 0, so dividing
    # would give 0/0 = NaN. Dividing by 1 instead leaves the column at 0.
    scale = np.where(scale == 0, 1.0, scale)
    return centered / scale

def get_features(path, norm=True):
    """Load two labelled feature matrices from an ``.npz`` archive.

    The archive must hold 2-D arrays under the keys ``'s'`` and ``'t'``
    (presumably "source" and "target" — confirm against the code that wrote
    the files). In each array the last column is split off as the label
    vector and the remaining columns are the features.

    Args:
        path: location of the ``.npz`` file, passed to ``np.load``.
        norm: when true, each feature matrix is passed through
            ``normalize`` on its own (the two matrices do not share
            statistics).

    Returns:
        Tuple ``(ssx, ssy, stx, sty)``: features and labels of ``'s'``,
        then features and labels of ``'t'``.
    """
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it as soon as both arrays have been read into memory.
    with np.load(path) as data:
        ss, st = data['s'], data['t']
    ssx, ssy = ss[:, :-1], ss[:, -1]
    stx, sty = st[:, :-1], st[:, -1]
    if norm:
        return normalize(ssx), ssy, normalize(stx), sty
    return ssx, ssy, stx, sty

def prototype_classifier(X, C, th=1):
    """Rank prototypes for every sample by a softmax over negative distances.

    Args:
        X: 2-D array of samples, one per row.
        C: 2-D array of prototypes, one per row, with the same number of
            columns as ``X``.
        th: how many of the most probable prototypes to return per sample.

    Returns:
        Tuple ``(top, prob)``.
        ``prob`` has shape ``(len(C), len(X))``; column ``j`` is the softmax,
        taken over prototypes, of the negated ``cdist`` distance between each
        prototype and sample ``j``.
        ``top`` has shape ``(min(th, len(C)), len(X))`` and holds prototype
        indices in ASCENDING order of probability, so ``top[-1]`` is the most
        probable prototype and ``top[-2]`` the runner-up.

    Raises:
        ValueError: if ``th`` is negative.
    """
    if th < 0:
        raise ValueError(f'th must be non-negative, got {th}')
    dist = cdist(C, X)
    prob = softmax(-dist, axis=0)
    ranking = prob.argsort(axis=0)
    # ``ranking[-th:]`` would return every row when th == 0 (``-0`` is ``0``),
    # so the start of the slice is computed explicitly.
    start = max(ranking.shape[0] - th, 0)
    return ranking[start:], prob

def masked_prototypical_classifier(X, C, ratio=0.3, seed=2437):
    """Run ``prototype_classifier`` on a random subset of the feature columns.

    Args:
        X: 2-D array of samples, one per row.
        C: 2-D array of prototypes with the same number of columns as ``X``.
        ratio: probability with which each feature column is dropped.
        seed: seed for the random generator that draws the column mask.

    Returns:
        Whatever ``prototype_classifier`` returns for the column-reduced
        ``X`` and ``C`` (called with its default ``th``).
    """
    n_features = X.shape[1]
    generator = np.random.default_rng(seed)
    # One Bernoulli(ratio) draw per column; a draw of 1 marks it as dropped.
    dropped = generator.binomial(1, ratio, n_features).astype(bool)
    keep = np.logical_not(dropped)
    return prototype_classifier(X[:, keep], C[:, keep])

In [None]:
# NOTE(review): machine-specific absolute path; point it at the local copy of
# the dataset before running.
data_path = Path('/work/chu980802/data/OfficeHome/Art')
# Class index -> class name, taken from the sub-directory names in sorted
# order. Plain files in the folder are skipped so they cannot shift indices.
class_name = {
    i: x.name
    for i, x in enumerate(sorted(p for p in data_path.iterdir() if p.is_dir()))
}

#### Protonet Pseudo Label Training Space

In [None]:
avg = 0

# Visit every ordered pair (s, t) of the four domain indices.
# NOTE(review): this cell only loads the features; no score is ever added to
# ``avg``, so the printed average is always 0.
domain_pairs = permutations(range(4), 2)
for s, t in domain_pairs:
    s_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
    features = get_features(s_path, norm=False)
    ssx, ssy, stx, sty = features

print(f'Avg acc.:', avg / 12)

#### Protonet Domain Center

In [None]:
avg = 0

# For each ordered pair of domains, check how well the two domains can be told
# apart by the nearest of two prototypes (one mean feature vector per domain).
for s, t in permutations(range(4), 2):
    s_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
    ssx, ssy, stx, sty = get_features(s_path, norm=False)

    # Row 0 of ``c`` is the mean of ssx, row 1 the mean of stx.
    sc, tc = ssx.mean(axis=0), stx.mean(axis=0)
    c = np.stack([sc, tc])

    # Domain labels matching the rows of ``c``: 0 for ssx rows, 1 for stx rows.
    x = np.concatenate([ssx, stx], axis=0)
    y = np.concatenate([np.zeros(len(ssy)), np.ones(len(sty))])
    pred, _ = prototype_classifier(x, c)

    # ``pred`` has a leading axis of length 1 (th=1) that broadcasts against y.
    score = np.mean(pred == y)
    dist = np.linalg.norm(sc - tc)
    print('-'*10, f'source {s}, target {t}', '-'*10)
    print(score)
    print(dist)
    avg += score
# 12 = number of ordered pairs drawn from 4 domains.
print(f'Avg acc.:', avg/12)

In [None]:
avg = 0
# Nearest-class-prototype accuracy on the second feature set of every ordered
# domain pair, with prototypes built from the first (labelled) feature set.
for s, t in permutations(range(4), 2):
    s_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
    ssx, ssy, stx, sty = get_features(s_path, norm=False)

    # One prototype per class: the mean feature vector of the ssx rows
    # labelled i, for each of the 65 classes.
    ssc = np.stack([ssx[ssy == i].mean(axis=0) for i in range(65)])
    pred, _ = prototype_classifier(stx, ssc)
    # ``pred`` has a leading axis of length 1 (th=1) that broadcasts against sty.
    score = (pred == sty).mean()
    print('-'*10, f'source {s}, target {t}', '-'*10)
    print(score)
    avg += score
# This cell has no C hyper-parameter, so none is reported. The previous
# message interpolated a global ``c`` left over from another cell.
print('Avg acc.:', avg/12)

#### Parameter selection on source

In [None]:
# Sweep the LinearSVC regularisation strength C. For every C a classifier is
# fitted on (ssx, ssy) of each ordered domain pair and its accuracy on
# (stx, sty) is averaged over the 12 pairs.
# NOTE(review): the section heading says "on source", but the score below is
# computed on stx/sty — confirm which split is meant to drive the selection.
c_list = [1e-4, 1e-3, 1e-2, 1e-1, 1, 10]
total_avg = []
for c in c_list:
    avg = 0
    for s, t in permutations(range(4), 2):
        s_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
        ssx, ssy, stx, sty = get_features(s_path, norm=False)

        # The class-mean prototypes and the arange label vector that used to
        # be built here were never read by this cell and have been removed.
        c1 = LinearSVC(random_state=2476, C=c, max_iter=5000)
        c1.fit(ssx, ssy)
        pred = c1.predict(stx)
        score = (pred == sty).mean()
        avg += score
    print(f'C = {c}, Avg acc.:', avg/12)
    total_avg.append(avg/12)
# Mean and spread of the per-C averages, i.e. sensitivity to the choice of C.
print('Total Avg acc.:', np.mean(total_avg))
print('std:', np.std(total_avg))

#### Protonet classifier for source on target

In [None]:
avg = 0
for s, t in permutations(range(4), 2):
    s_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
    ssx, ssy, stx, sty = get_features(s_path, norm=False)

    # One prototype per class: mean of the ssx rows labelled i, 65 classes.
    class_means = [ssx[ssy == i].mean(axis=0) for i in range(65)]
    ssc = np.stack(class_means)
    # 1000 differently seeded random feature masks; each prediction array is
    # printed as it is produced.
    for i in range(1000):
        masked_result = masked_prototypical_classifier(stx, ssc, ratio=0.3, seed=i)
        pred1, _ = masked_result

        print(pred1)
    # Only the first domain pair is processed.
    break
# NOTE(review): scoring is disabled in this cell, so ``avg`` is never updated
# and the line below always prints 0.
print(f'Avg acc.:', avg/12)

#### Hyperparameter selection

In [None]:
# Sweep the LinearSVC regularisation strength C. For each C, a classifier is
# fitted on the 65 class prototypes (class means of tsx) of every ordered
# domain pair and scored on (ttx, tty); the 12 scores are averaged.
c_list = [1e-4, 1e-3, 1e-2, 1e-1, 1, 10]
total_avg = []
ratio = 0.2
# Prototype i is the mean of class i, so the training labels are 0..64.
labels = np.arange(65)
for c in c_list:
    avg = 0
    for s, t in permutations(range(4), 2):
        # The file name carries the pair in swapped order (s{t}_t{s}).
        t_path = f'./OfficeHome/partial/s{t}_t{s}_{ratio}.npz'
        ttx, tty, tsx, tsy = get_features(t_path, norm=False)

        class_means = [tsx[tsy == i].mean(axis=0) for i in range(65)]
        tsc = np.stack(class_means)

        c1 = LinearSVC(random_state=12453, C=c, max_iter=5000)
        c1.fit(tsc, labels)
        score = c1.score(ttx, tty)
        avg += score
    mean_acc = avg/12
    print(f'C = {c}, Avg acc.:', mean_acc)
    total_avg.append(mean_acc)
print('Total Avg acc.:', np.mean(total_avg))
print('std:', np.std(total_avg))

#### Protonet for partial space

In [None]:
avg = 0
ratio = 0.2
# Nearest-class-prototype accuracy on the "partial" feature files of every
# ordered domain pair.
for s, t in permutations(range(4), 2):
    # The file name carries the pair in swapped order (s{t}_t{s}).
    t_path = f'./OfficeHome/partial/s{t}_t{s}_{ratio}.npz'
    ttx, tty, tsx, tsy = get_features(t_path, norm=False)
    # One prototype per class: mean of the tsx rows labelled i, 65 classes.
    class_means = [tsx[tsy == i].mean(axis=0) for i in range(65)]
    tsc = np.stack(class_means)
    pred, _ = prototype_classifier(ttx, tsc)
    # ``pred`` has a leading axis of length 1 (th=1) that broadcasts against tty.
    score = np.mean(pred == tty)
    print('-'*10, f'source {s}, target {t}', '-'*10)
    print(score)
    avg += score
print('Avg acc.:', avg/12)

#### Top-2 accuracy

In [None]:
avg = 0
n_evaluated = 0  # number of domain pairs that actually contributed to ``avg``
for s, t in permutations(range(4), 2):
    s_path = f'./OfficeHome/fixbi/s{s}_t{t}.npz'
    ssx, ssy, stx, sty = get_features(s_path, norm=False)
    # One prototype per class: mean of the ssx rows labelled i, 65 classes.
    ssc = np.stack([ssx[ssy == i].mean(axis=0) for i in range(65)])

    pred, _ = prototype_classifier(stx, ssc, th=2)
    # ``pred`` has shape (2, n_samples) with rows in ascending probability
    # order, so after transposing column 0 is the runner-up class and
    # column 1 the top-1 class.
    top2 = pred.T
    for i in range(65):
        # Samples of true class i whose runner-up prediction is i; count
        # which classes were ranked first for them.
        # NOTE(review): if column 0 was meant to be the top-1 prediction,
        # the two column indices below need swapping — confirm the intent.
        p = top2[(top2[:, 0] == i) & (sty == i)]
        first_ranked = Counter(p[:, 1])
        print([(class_name[a], b) for a, b in first_ranked.most_common(3)])
    # Top-2 hit: the true label equals either of the two retained predictions.
    score = (pred == sty).any(axis=0).mean()
    print('-'*10, f'source {s}, target {t}', '-'*10)
    print(score)
    avg += score
    n_evaluated += 1
    # Only the first domain pair is evaluated; remove to process all 12.
    break
# Divide by the number of pairs actually evaluated; a fixed 12 would shrink
# the result while the loop stops early.
print('Avg acc.:', avg / n_evaluated)