In [1]:
%cd ..

f:\DS Lab\OT\ot-kpgg-fc


In [2]:
import numpy as np
import os
import ot
import scipy.io as sio
import torch
from sklearn.svm import SVC
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

from optimal_transport.models import KeypointFOT

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.pylab as pl
import ot.plot
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

import warnings
warnings.filterwarnings('ignore')

from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
  register_backend(TensorflowBackend())


In [3]:
def class_balancing(x, t, num=10, seed=0):
    """Resample every class to exactly ``num`` rows, drawing with replacement.

    Parameters
    ----------
    x : array, shape (n_samples, ...)
        Samples; rows are selected with a boolean mask built from ``t``.
    t : array
        One label per row of ``x``. Any shape that flattens to (n_samples,)
        is accepted, e.g. the (n, 1) column vectors returned by ``loadmat``.
    num : int
        Number of rows drawn (with replacement) for each class.
    seed : int
        Passed to ``np.random.seed`` before drawing.

    Returns
    -------
    x_class_balance : array, shape (n_classes * num, ...)
        Resampled rows, grouped by class in ascending label order.
    t_class_balance : int32 array, shape (n_classes * num,)
        The label of each resampled row.

    Raises
    ------
    ValueError
        If ``t`` is empty (nothing to stack).

    Notes
    -----
    The classes are the distinct values actually present in ``t``, so
    0-based or non-contiguous labels are handled; for contiguous 1-based
    labels this matches iterating over ``1..t.max()``.
    The *global* NumPy RNG is reseeded on purpose: code that runs afterwards
    without its own seed (e.g. ``SVC`` with ``random_state=None``) draws from
    that global state, so this keeps a run reproducible.
    """
    labels = np.asarray(t).reshape(-1)
    classes = np.unique(labels)

    np.random.seed(seed)
    x_parts = []
    t_parts = []
    for k in classes:
        members = x[labels == k]
        # Sampling from range(len(members)) consumes the RNG exactly like
        # sampling from np.arange(len(members)).
        picks = np.random.choice(len(members), num, replace=True)
        x_parts.append(members[picks])
        t_parts.append(np.full(num, k, dtype=np.int32))

    return np.vstack(x_parts), np.hstack(t_parts)

In [4]:
def load_data(source, target, num_labeled=1, class_balance=True, num=10, seed=0,
              data_dir="data/hda", target_seed=0):
    """Load one heterogeneous task: DeCAF fc6 source features vs. ResNet-50 target features.

    Parameters
    ----------
    source, target : str
        Domain names used to build the file names (e.g. ``"amazon"``).
    num_labeled : int
        Selects the split file ``labeled_index/{target}_{num_labeled}.mat``.
    class_balance : bool
        If True, resample the source set and the unlabeled target set to
        ``num`` rows per class with ``class_balancing``. The labeled target
        set is never resampled.
    num : int
        Rows per class when ``class_balance`` is True.
    seed : int
        Seed used when resampling the source set.
    data_dir : str
        Root folder containing ``decaf/``, ``resnet50/`` and ``labeled_index/``.
    target_seed : int
        Seed used when resampling the unlabeled target set. It defaults to 0
        and is independent of ``seed``, so the unlabeled target set is the
        same whatever ``seed`` is. Because ``class_balancing`` reseeds the
        global NumPy RNG, the global RNG state after this call is determined
        by ``target_seed`` (when ``class_balance`` is True).

    Returns
    -------
    tuple
        ``(source_feat, source_label, target_feat_l, target_label_l,
        target_feat_un, target_label_un)``. Labels are 1-D.
    """
    source_data = sio.loadmat("{}/decaf/{}_fc6.mat".format(data_dir, source))
    target_data = sio.loadmat("{}/resnet50/{}.mat".format(data_dir, target))

    source_feat = source_data["fts"]
    target_feat = target_data["fts"]
    # loadmat returns MATLAB-style 2-D arrays; flatten the labels to 1-D.
    source_label = source_data["labels"].reshape(-1)
    target_label = target_data["labels"].reshape(-1)

    indexes = sio.loadmat("{}/labeled_index/{}_{}.mat".format(data_dir, target, num_labeled))
    # NOTE(review): these are used directly as NumPy (0-based) row indices —
    # confirm the split files were not saved with MATLAB 1-based indices.
    idx_labeled = indexes["labeled_index"][0]
    idx_unlabeled = indexes["unlabeled_index"][0]

    target_feat_l, target_label_l = target_feat[idx_labeled], target_label[idx_labeled]
    target_feat_un, target_label_un = target_feat[idx_unlabeled], target_label[idx_unlabeled]

    if class_balance:
        source_feat, source_label = class_balancing(source_feat, source_label, num=num, seed=seed)
        target_feat_un, target_label_un = class_balancing(target_feat_un, target_label_un,
                                                          num=num, seed=target_seed)

    return source_feat, source_label, target_feat_l, target_label_l, target_feat_un, target_label_un

In [5]:
# Experiment configuration: the domains are combined pairwise as
# (source, target) in the experiment loop.
domains = "amazon dslr webcam".split()
num_labeled = 1  # selects the labeled_index/{target}_{num_labeled}.mat split
seed = 1         # seed used when resampling the source set

# Result accumulators: one task name and one accuracy per evaluated pair.
Tasks, Accs = [], []

In [6]:
# Evaluate every (source, target) pair. The accumulators are rebuilt here so
# that re-running this cell does not append duplicate rows (and fold a stale
# "avg" entry into the new average).
Tasks = []
Accs = []

# Free anchors / clusters handed to KeypointFOT.
# Presumably one per class of the dataset — TODO confirm.
n_keypoints = 31

for source in domains:
    for target in domains:
        print("source:{} --> target:{}".format(source, target))

        feat_s, label_s, feat_tl, label_tl, feat_tu, label_tu = load_data(source, target, num_labeled, seed=seed)

        # Keypoints: for each labeled target sample, anchor the mean of the
        # source features carrying the same label. Pair (i, i) links row i of
        # the augmented source set with row i of the augmented target set.
        feat_sl = np.vstack([feat_s[label_s == l].mean(axis=0) for l in label_tl])
        K = [(i, i) for i in range(len(label_tl))]

        label_s_ = np.concatenate((label_tl, label_s))
        feat_s_ = np.vstack((feat_sl, feat_s))
        feat_t_ = np.vstack((feat_tl, feat_tu))

        # Uniform marginals over the augmented source and target sets.
        p = np.ones(len(feat_s_)) / len(feat_s_)
        q = np.ones(len(feat_t_)) / len(feat_t_)

        model = KeypointFOT(label_s_, n_free_anchors=n_keypoints, alpha=0.5, stop_thr=1e-5,
                            sinkhorn_reg=0.005, temperature=0.1, div_term=1e-10, max_iters=20, n_clusters=31)
        result = model.fit(feat_s_, feat_t_, p, q, K)

        # Composite source -> target plan through the anchors.
        # Assumes Pa_ is source->anchors and Pb_ is anchors->target — TODO confirm.
        P = result.Pa_.dot(result.Pb_)
        # Each source row becomes the P-weighted combination of target rows.
        # Dividing by p normalises the rows, assuming the rows of P sum to p.
        feat_s_trans = P @ feat_t_ / p.reshape(-1, 1)

        # Training set: the labeled target rows plus the transported source rows.
        # The first len(feat_tl) transported rows are the keypoint class-means,
        # which are skipped; the remaining rows line up with label_s.
        feat_train = np.vstack((feat_tl, feat_s_trans[len(feat_tl):]))
        label_train = np.hstack((label_tl, label_s))

        print("train svm...")
        # No probability=True: score() only uses predict(), which does not
        # depend on the Platt-scaling model, so it only added an extra CV fit.
        clf = SVC(gamma='auto')
        clf.fit(feat_train, label_train)
        acc = clf.score(feat_tu, label_tu)

        Tasks.append(source[0].upper() + "2" + target[0].upper())
        Accs.append(round(acc * 100, 2))

Tasks.append("avg")
Accs.append(round(np.mean(np.array(Accs)), 2))
print("task:\tacc")
for task_name, task_acc in zip(Tasks, Accs):
    print("{:}:\t{:.2f}".format(task_name, task_acc))
            

source:amazon --> target:amazon
inital z:  [[-37.300606    -2.0023358   -2.5098934  ...  -7.931305   -37.836494
  -25.38333   ]
 [-34.860893    -3.0928895   -0.89808255 ...  -6.9680924  -26.88684
  -24.823849  ]
 [ -4.1532006    6.455011    -8.671114   ...  -6.4367514  -33.522392
  -25.649689  ]
 ...
 [-10.267283    -2.8322256  -12.2870455  ... -23.858223   -34.407623
    6.583577  ]
 [-21.668589    -3.0575736  -20.682615   ... -14.8679     -21.98196
   -7.8991456 ]
 [-30.451729    -6.4366217   -4.031206   ... -12.136512   -27.422083
   29.12361   ]]
[[-18.34748877  -1.52224882   1.43434421 ...  -1.60520291 -20.25070074
  -13.75643358]
 [-18.65029347  -1.00115848  -1.25493729 ...  -3.96564305 -18.9182378
  -12.69165528]
 [ -1.7961991    3.92184758  -4.60209266 ...  -3.56478209 -16.74255339
  -13.09256156]
 ...
 [-11.41604711  -0.28290354  -5.32045175 ...  -6.682961   -14.6607113
   -6.44259073]
 [-12.9490743   -1.55036856  -8.48408833 ...  -5.14465429 -10.32182523
   -7.23005752]
 [-10