In [1]:
import  torch
import  numpy as np
from dataTST import SynDataMetaTST
from MetaTST import Meta
import argparse
parser = argparse.ArgumentParser()
from utils import MatConvert, TST_MMD_u, MMDu, TST_MMD_Multi

In [2]:
class myargs:
    """Plain attribute container standing in for parsed command-line arguments.

    The instance is handed to ``Meta(args, config)`` and read by the setup
    cell. ``meta_lr``, ``update_lr`` and ``update_step`` are not read anywhere
    in this notebook; presumably they are consumed inside ``Meta`` -- confirm
    against that module.
    """

    def __init__(self):
        self.n=50                # samples per mode in each meta-task
        self.n_te=150            # training samples per mode for the target task
        self.d=2                 # data dimension
        self.K=10                # number of trials (last axis of `Results`)
        self.num_meta_tasks=100  # number of meta-samples
        self.epoch=1000          # meta-training steps passed to `alg_one`
        self.meta_lr=1e-2
        self.update_lr=0.8
        self.update_step=10
        self.closeness=0.3       # shifts the meta-task covariance range

    def __repr__(self):
        # Without this, `print(args)` only shows an opaque object address.
        fields = ", ".join(
            "{}={!r}".format(key, value) for key, value in vars(self).items()
        )
        return "{}({})".format(type(self).__name__, fields)
args = myargs()

In [3]:
def mu_sigma(delta, dim=None):
    """Build the means and covariances of a two-cluster Gaussian mixture.

    Args:
        delta: off-diagonal covariance between coordinates 0 and 1; cluster 0
            gets ``+delta`` and cluster 1 gets ``-delta``.
        dim: data dimension. When omitted, the notebook-level ``d`` is used,
            which keeps existing ``mu_sigma(delta)`` calls working. Must be at
            least 2 because entries ``[0, 1]`` and ``[1, 0]`` are written.

    Returns:
        ``(mu, sigma)``: ``mu`` is a ``(2, dim)`` array whose first row is all
        zeros and second row is all 0.5; ``sigma`` is a list of two
        ``(dim, dim)`` identity matrices with the off-diagonals set as above.
    """
    if dim is None:
        # Fall back to the global defined in the setup cell.
        dim = d
    Num_clusters = 2
    mu = np.zeros([Num_clusters, dim])
    mu[1] = mu[1] + 0.5
    sigma = [np.identity(dim), np.identity(dim)]
    sigma[0][0, 1] = delta
    sigma[0][1, 0] = delta
    sigma[1][0, 1] = -delta
    sigma[1][1, 0] = -delta
    return mu, sigma

In [4]:
def gen_data(n, delta = 0.7, kk=0):
    """Draw one two-sample problem from the two-cluster Gaussian mixtures.

    The first sample uses ``mu_sigma(0)`` and the second uses
    ``mu_sigma(delta)``. The global NumPy RNG is re-seeded before every
    cluster draw from ``kk``, the cluster index and ``n``, so the output is
    deterministic and the global RNG state is overwritten as a side effect.
    Relies on the notebook-level ``Num_clusters`` and ``d``.

    Returns:
        ``(S, s1, s2)`` where ``s1`` and ``s2`` each hold ``n`` rows per
        cluster and ``S`` is ``s1`` stacked on top of ``s2``.
    """
    mean_p, cov_p = mu_sigma(0)
    mean_q, cov_q = mu_sigma(delta)
    sample_shape = [n * Num_clusters, d]
    s1, s2 = np.zeros(sample_shape), np.zeros(sample_shape)

    # First sample: all cluster draws happen before any draw of the second sample.
    for cluster in range(Num_clusters):
        rows = slice(n * cluster, n * (cluster + 1))
        np.random.seed(seed=1102 * kk + cluster + n)
        s1[rows, :] = np.random.multivariate_normal(mean_p[cluster], cov_p[cluster], n)

    # Second sample, seeded from a different multiplier and offset.
    for cluster in range(Num_clusters):
        rows = slice(n * cluster, n * (cluster + 1))
        np.random.seed(seed=819 * kk + 1 + cluster + n)
        s2[rows, :] = np.random.multivariate_normal(mean_q[cluster], cov_q[cluster], n)

    return np.concatenate((s1, s2), axis=0), s1, s2

In [5]:
def gen_meta_data(num_meta_tasks, n, closeness):
    """Generate the stacked samples for every meta-task.

    Task ``nn`` (0-based) is produced by ``gen_data(n, delta, nn)`` with
    ``delta = 0.6 - closeness + (0.1 / num_meta_tasks) * (nn + 1)``.

    The output shape is taken from what ``gen_data`` returns instead of a
    hard-coded ``(4 * n, 2)``, so the function keeps working when the data
    dimension or cluster count changes.

    Returns:
        Array of shape ``(num_meta_tasks, rows, dim)``. With zero tasks an
        empty ``(0, 4 * n, 2)`` array is returned, matching the shape the
        previous implementation produced.
    """
    tasks = [
        gen_data(n, 0.6 - closeness + (0.1 / num_meta_tasks) * (nn + 1), nn)[0]
        for nn in range(num_meta_tasks)
    ]
    if not tasks:
        # Nothing to stack; keep the historical empty shape.
        return np.empty((0, 4 * n, 2))
    return np.stack(tasks, axis=0)

In [6]:
def alg_one(maml, db_train, epoch):
    """Run ``epoch`` meta-training steps and return the last step's kernel.

    Each step pulls one batch from ``db_train.next()``, moves the four arrays
    to the notebook-level ``device`` and calls ``maml`` on them. The objective
    is printed every 10th step.

    Returns:
        ``(model_u, sigma, sigma0_u, ep)`` as produced by the final ``maml``
        call. With ``epoch == 0`` these names are never bound, so the return
        statement raises ``UnboundLocalError``.
    """
    for step in range(epoch):
        x_spt, y_spt, x_qry, y_qry = db_train.next()
        x_spt, y_spt, x_qry, y_qry = [
            torch.from_numpy(batch).to(device)
            for batch in (x_spt, y_spt, x_qry, y_qry)
        ]
        # One meta-update of the kernel.
        J_value, model_u, sigma, sigma0_u, ep = maml(x_spt, y_spt, x_qry, y_qry)
        if step % 10 != 0:
            continue
        print('step:', step, '\ttraining J value:', J_value.item())
    return model_u, sigma, sigma0_u, ep

In [7]:
def get_no_trainable_tensors(model):
    """Print and return how many scalar entries of ``model`` are trainable.

    Despite the name, this sums the element counts of every parameter with
    ``requires_grad`` set, not the number of parameter tensors.
    """
    num = sum(
        np.prod(param.shape)
        for param in model.parameters()
        if param.requires_grad
    )
    print('Total trainable tensors:', num)
    return num

In [8]:
def fine_tune(model, it, learning_rate, S, sigma, sigma0, ep, device, dtype):
    """Fine-tune a deep kernel on random sub-samples of ``S``.

    Every iteration draws, without replacement, one fifth of the rows of each
    half of ``S``, evaluates ``MMDu`` on that subset and takes one Adam step
    on ``-(MMD + 1e-8) / sqrt(var + 1e-8)``.

    The sub-samples are taken from the ``S`` argument itself. Earlier
    revisions read the notebook globals ``s1_te`` / ``s2_te`` and ignored
    ``S``, which silently trained on stale data whenever the two disagreed.

    Args:
        model: feature network; its parameters are updated in place.
        it: number of optimisation steps.
        learning_rate: Adam learning rate.
        S: stacked samples, first half rows from the first distribution and
            second half from the second. Tensor or NumPy array.
        sigma, sigma0, ep: tensors whose squares are passed to ``MMDu``;
            they are optimised together with the model.
        device, dtype: placement used for ``S`` and the sub-samples.

    Returns:
        ``(model, sigma, sigma0, ep)`` after the last step.
    """
    if not torch.is_tensor(S):
        S = MatConvert(S, device, dtype)
    S = S.to(device, dtype)
    N1 = S.shape[0] // 2          # rows per sample
    n_random = int(N1 / 5)        # sub-sample size per sample

    # setup optimizer for training deep kernel
    optimizer = torch.optim.Adam(list(model.parameters()) + [ep] + [sigma] + [sigma0],
                                   lr=learning_rate)
    for t in range(it):
        # one way to train kernel with limited data: random rows of each sample
        selected_cls1 = np.random.choice(N1, n_random, False)
        selected_cls2 = np.random.choice(N1, n_random, False)
        rows_1 = torch.from_numpy(selected_cls1).long().to(S.device)
        rows_2 = torch.from_numpy(selected_cls2).long().to(S.device) + N1
        S_random = torch.cat((S[rows_1], S[rows_2]), dim=0)
        # another way to train kernel with limited data (similar performance)
        # S_random = S
        # n_random = N1

        # Compute epsilon, sigma and sigma_0
        ep_ = ep ** 2
        sigma_ = sigma ** 2
        sigma0_ = sigma0 ** 2

        # Compute output of the deep network
        model_output = model(S_random)

        # Compute J (STAT_u); the sign flip turns maximisation into a loss
        TEMP = MMDu(model_output, n_random, S_random, sigma_, sigma0_, ep_)
        mmd_value_temp = -1 * (TEMP[0] + 10 ** (-8))
        mmd_std_temp = torch.sqrt(TEMP[1] + 10 ** (-8))
        std_value = mmd_std_temp.item()
        if std_value == 0 or np.isnan(std_value):
            print('error!!')
        STAT = torch.div(mmd_value_temp, mmd_std_temp)

        # Initialize optimizer and Compute gradient
        optimizer.zero_grad()
        # NOTE(review): the graph is rebuilt every iteration, so retain_graph
        # looks unnecessary; kept to avoid changing training behaviour.
        STAT.backward(retain_graph=True)

        # Update weights using gradient descent
        optimizer.step()
        # Print MMD, std of MMD and J
        if t % 100 == 0:
            print("mmd_value: ", -1 * mmd_value_temp.item(), "mmd_std: ", mmd_std_temp.item(), "Statistic: ",
                  -1 * STAT.item())
    return model, sigma, sigma0, ep

In [9]:
def DkTST(maml_init, epoch, learning_rate, S, device, dtype):
    """Train a deep kernel on ``S`` starting from ``maml_init``'s initial state.

    The kernel parameters returned by ``maml_init.get_init()`` are
    re-parameterised as their square roots, so their squares (what ``MMDu``
    receives) start at the initial values. Each of the ``epoch`` steps uses
    the full ``S`` and minimises ``-(MMD + 1e-8) / sqrt(var + 1e-8)``.

    The per-step history buffer was removed: it was never read, and it was
    sized by the global ``N_epoch``, which raised ``IndexError`` whenever
    ``epoch`` exceeded it.

    Args:
        maml_init: object exposing ``get_init()``; the first returned item is
            ignored, the rest are the model and the three kernel parameters.
        epoch: number of Adam steps.
        learning_rate: Adam learning rate.
        S: stacked samples; the first half of the rows is treated as sample 1.
        device, dtype: placement of the kernel-parameter tensors.

    Returns:
        ``(model, sigma, sigma0, ep)`` where the last three are the
        square-root parameters after training.
    """
    _,model, sigma_init, sigma0_u_init, ep_init = maml_init.get_init()
    # Seeds come from the notebook-level `n` and `kk`.
    # NOTE(review): the CPU seed uses the literal 1 while the CUDA seed uses
    # `kk`; confirm whether both were meant to use `kk`.
    torch.manual_seed(1 * 19 + n)
    torch.cuda.manual_seed(kk * 19 + n)
    ep = MatConvert(np.ones(1) * np.sqrt(ep_init.detach().cpu().numpy()), device, dtype)
    ep.requires_grad = True
    sigma = MatConvert(np.ones(1) * np.sqrt(sigma_init.detach().cpu().numpy()), device, dtype)
    sigma.requires_grad = True
    sigma0 = MatConvert(np.ones(1) * np.sqrt(sigma0_u_init.detach().cpu().numpy()), device, dtype)
    sigma0.requires_grad = True
    # Setup optimizer for training init kernel
    optimizer = torch.optim.Adam(list(model.parameters()) + [ep] + [sigma] + [sigma0],
                               lr=learning_rate)
    N1 = int(S.shape[0]/2)
    for t in range(epoch):
        # Compute epsilon, sigma and sigma_0
        ep_ = ep ** 2
        sigma_ = sigma ** 2
        sigma0_ = sigma0 ** 2
        # Compute output of the deep network
        model_output = model(S)
        # Compute J (STAT_u)
        TEMP = MMDu(model_output, N1, S, sigma_, sigma0_, ep_)
        mmd_value_temp = -1 * (TEMP[0] + 10 ** (-8))
        mmd_std_temp = torch.sqrt(TEMP[1] + 10 ** (-8))
        std_value = mmd_std_temp.item()
        if std_value == 0 or np.isnan(std_value):
            print('error!!')
        STAT = torch.div(mmd_value_temp, mmd_std_temp)
        # STAT_u = mmd_value_temp # D+M
        # Initialize optimizer and Compute gradient
        optimizer.zero_grad()
        STAT.backward(retain_graph=True)
        # Update weights using gradient descent
        optimizer.step()
        # Print MMD, std of MMD and J
        if t % 100 == 0:
            print("mmd_value_init: ", -1 * mmd_value_temp.item(), "mmd_std_init: ", mmd_std_temp.item(), "Statistic_init: ",
                  -1 * STAT.item())
    return model, sigma, sigma0, ep

In [10]:
# Global seeds for reproducibility.
torch.manual_seed(222)
torch.cuda.manual_seed_all(222)
np.random.seed(222)

dtype = torch.float
# Use the first GPU when present; otherwise fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

d = args.d  # dimension of data
n = args.n  # number of samples in per mode
n_te = args.n_te # number of training samples for the target task
K = args.K  # number of trials
num_meta_tasks = args.num_meta_tasks # number of meta-samples
print('n: ' + str(n) + ' d: ' + str(d))

N_per = 100  # permutation times
alpha = 0.05  # test threshold
x_in = d  # number of neurons in the input layer, i.e., dimension of data
H = 30  # number of neurons in the hidden layer
x_out = 3 * d  # number of neurons in the output layer
learning_rate = 0.00005 # learning rate for MMD-D
N_epoch = 1000  # maximum number of epochs for training
N = 100  # number of test sets
N_f = 100.0  # number of test sets (float)
list_nte = [50, 80, 100, 120, 150, 200, 250] # number of test samples for the target task

# Layer specification handed to `Meta`: three hidden softplus layers of width H.
config = [
    ('linear', [H, x_in]),
    ('softplus', [True]),
    ('linear', [H, H]),
    ('softplus', [True]),
    ('linear', [H, H]),
    ('softplus', [True]),
    ('linear', [x_out, H]),
]

print(args)

# Generate variance and co-variance matrix of Q (target task)
Num_clusters = 2
mu_mx = np.zeros([Num_clusters, d])
mu_mx[1] = mu_mx[1] + 0.5
sigma_mx_1 = np.identity(d)
sigma_mx_2 = [np.identity(d), np.identity(d)]
sigma_mx_2[0][0, 1] = 0.7
sigma_mx_2[0][1, 0] = 0.7
sigma_mx_2[1][0, 1] = -0.7
sigma_mx_2[1][1, 0] = -0.7

# Same means and identity covariances for the meta-tasks; the off-diagonals of
# sigma_mx_2_s are filled in per task by the meta-sample generation loop.
mu_mx_s = np.zeros([Num_clusters, d])
mu_mx_s[1] = mu_mx_s[1] + 0.5
sigma_mx_1_s = np.identity(d)
sigma_mx_2_s = [np.identity(d), np.identity(d)]

# Pre-allocated buffers
s1 = np.zeros([n * Num_clusters, d])
s2 = np.zeros([n * Num_clusters, d])
s1_te = np.zeros([n_te * Num_clusters, d])
s2_te = np.zeros([n_te * Num_clusters, d])
J_star_u = np.zeros([N_epoch])
J_star_u_init = np.zeros([N_epoch])
Results = np.zeros([len(list_nte), 2, K])

n: 50 d: 2
<__main__.myargs object at 0x7f5a43f09880>


In [11]:
kk = 0
# Two independent learners: `maml` is meta-trained below, `maml_init` stays at
# its initial state and later seeds the MMD-D baseline.
maml = Meta(args, config).to(device)
maml_init = Meta(args, config).to(device)

get_no_trainable_tensors(maml)
get_no_trainable_tensors(maml_init)

# generate meta-samples
data_org = np.random.randn(num_meta_tasks, 4 * n, 2)
for nn in range(num_meta_tasks):
    # Off-diagonal covariance of task nn: +delta in cluster 0, -delta in cluster 1.
    delta_nn = 0.6 - args.closeness + (0.1/num_meta_tasks) * (nn+1)
    sigma_mx_2_s[0][0, 1] = delta_nn
    sigma_mx_2_s[0][1, 0] = delta_nn
    sigma_mx_2_s[1][0, 1] = -delta_nn
    sigma_mx_2_s[1][1, 0] = -delta_nn

    # First sample: identity covariance in every cluster.
    for i in range(Num_clusters):
        cluster_rows = slice(n * i, n * (i + 1))
        np.random.seed(seed=1102*nn + i + n)
        s1[cluster_rows, :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1_s, n)
    # Second sample: correlated clusters.
    for i in range(Num_clusters):
        cluster_rows = slice(n * i, n * (i + 1))
        np.random.seed(seed=819*nn + 1 + i + n)
        s2[cluster_rows, :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_2_s[i], n)
    data_org[nn] = np.concatenate((s1, s2), axis=0)

# get training loader for meta-samples
# NOTE(review): the meaning of the positional arguments 10, 2, 150, 50 is
# defined in dataTST.SynDataMetaTST, which is not visible here.
db_train = SynDataMetaTST(data_org, 10, 2, 150, 50)

model_u, sigma, sigma0_u, ep = alg_one(maml, db_train, args.epoch)

Total trainable tensors: 2136
Total trainable tensors: 2136
DB: train (80, 200, 2) test (20, 200, 2)
sigma: 0.20559516549110413 sigma0: 0.01671348512172699 epsilon: 2.767058942513399e-17
J_value: 0.010120658203959465
step: 0 	training J value: 0.010120658203959465
sigma: 0.1966267079114914 sigma0: 0.014227871783077717 epsilon: 5.245743182058504e-07
J_value: -0.004747361410409212
sigma: 0.20249001681804657 sigma0: 0.015475654043257236 epsilon: 6.66746636852622e-05
J_value: -0.0004521567316260189
sigma: 0.20700708031654358 sigma0: 0.016662057489156723 epsilon: 0.0002396562194917351
J_value: 0.02556450478732586
sigma: 0.21307934820652008 sigma0: 0.01764904148876667 epsilon: 0.00034253718331456184
J_value: 0.011304693296551704
sigma: 0.21772420406341553 sigma0: 0.019145991653203964 epsilon: 0.000581011117901653
J_value: -0.013781196437776089
sigma: 0.22256694734096527 sigma0: 0.02080332301557064 epsilon: 0.0009742077672854066
J_value: -0.011179816909134388
sigma: 0.22246505320072174 sigma0

In [12]:
# setup meta kernels
torch.manual_seed(1 * 19 + n)
torch.cuda.manual_seed(kk * 19 + n)
# Re-parameterise the meta-learned kernel parameters as their square roots so
# that fine-tuning optimises values whose squares are fed to MMDu.
epsilonOPT, sigmaOPT, sigma0OPT = [
    MatConvert(np.ones(1) * np.sqrt(meta_value.detach().cpu().numpy()), device, dtype)
    for meta_value in (ep, sigma, sigma0_u)
]
for leaf in (epsilonOPT, sigmaOPT, sigma0OPT):
    leaf.requires_grad = True
print(epsilonOPT.item())

# Generate training data for target tasks
# S = gen_data(n_te)[0]
for i in range(Num_clusters):
    cluster_rows = slice(n_te * i, n_te * (i + 1))
    np.random.seed(seed=1102*kk + i + n)
    s1_te[cluster_rows, :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te)
for i in range(Num_clusters):
    cluster_rows = slice(n_te * i, n_te * (i + 1))
    np.random.seed(seed=819*kk + 1 + i + n)
    s2_te[cluster_rows, :] = np.random.multivariate_normal(mu_mx_s[i], sigma_mx_2[i], n_te)
S = MatConvert(np.concatenate((s1_te, s2_te), axis=0), device, dtype)
N1 = Num_clusters*n_te

# Meta Kernels as init when training with training set from the target task
np.random.seed(seed=1102)
torch.manual_seed(1102)
torch.cuda.manual_seed(1102)
it = 101
model_u, sigmaOPT, sigma0OPT, epsilonOPT = fine_tune(
    model_u, it, learning_rate/3, S, sigmaOPT, sigma0OPT, epsilonOPT, device, dtype
)

0.18957842886447906
mmd_value:  -0.0007402051705867052 mmd_std:  0.019663548097014427 Statistic:  -0.03764352202415466
mmd_value:  -0.0002810147125273943 mmd_std:  0.015200947411358356 Statistic:  -0.018486658111214638


In [13]:
# Baseline (MMD-D): train a deep kernel on the target-task training set S,
# starting from maml_init's initial parameters instead of the meta-learned
# ones, to validate the consistency of MMD-D's performance.
# print(epsilonOPT_init.item())
# NOTE(review): sigma_init / sigma0_u_init / ep_init captured on the get_init()
# line are the values later passed to TST_MMD_u in the evaluation cell; the
# trained sigmaOPT_init / sigma0OPT_init / epsilonOPT_init returned by DkTST
# are not used there. Confirm this is intended.
epoch = 1000
_,model_u_init, sigma_init, sigma0_u_init, ep_init = maml_init.get_init()
model_u_init, sigmaOPT_init, sigma0OPT_init, epsilonOPT_init = DkTST(maml_init, epoch, learning_rate, S, device, dtype)

mmd_value_init:  0.0001281559088965878 mmd_std_init:  0.005799873266369104 Statistic_init:  0.02209632843732834
mmd_value_init:  0.0002845390699803829 mmd_std_init:  0.006900147534906864 Statistic_init:  0.04123666509985924
mmd_value_init:  0.0003856457769870758 mmd_std_init:  0.007440447807312012 Statistic_init:  0.05183098837733269
mmd_value_init:  0.0004607635783031583 mmd_std_init:  0.007640775293111801 Statistic_init:  0.06030324846506119
mmd_value_init:  0.0005179089494049549 mmd_std_init:  0.007686390075832605 Statistic_init:  0.06737999618053436
mmd_value_init:  0.0005584777100011706 mmd_std_init:  0.007642611861228943 Statistic_init:  0.07307419180870056
mmd_value_init:  0.0005837843054905534 mmd_std_init:  0.007472391240298748 Statistic_init:  0.0781254991889
mmd_value_init:  0.000594753073528409 mmd_std_init:  0.007140212692320347 Statistic_init:  0.0832962691783905
mmd_value_init:  0.0006005028262734413 mmd_std_init:  0.006705150939524174 Statistic_init:  0.0895584374666214

In [14]:
kk = 0
# test the trained kernel on the target task (with different sample size: 50, 80, 100, 120, 150, 200, 250)
for i_test, n_te2 in enumerate(list_nte):
    N1_te2 = Num_clusters * n_te2
    s1_te2 = np.zeros([N1_te2, d])
    s2_te2 = np.zeros([N1_te2, d])
    # Per-test-set decision, threshold and statistic for both methods.
    H_u, T_u, M_u = np.zeros(N), np.zeros(N), np.zeros(N)
    H_u_init, T_u_init, M_u_init = np.zeros(N), np.zeros(N), np.zeros(N)
    np.random.seed(1102)
    count_u = 0
    count_u_init = 0
    for k in range(N):
        # Generate target tasks
        for i in range(Num_clusters):
            cluster_rows = slice(n_te2 * i, n_te2 * (i + 1))
            np.random.seed(seed=1102 * (k+2) + 2*kk + i + n)
            s1_te2[cluster_rows, :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        for i in range(Num_clusters):
            cluster_rows = slice(n_te2 * i, n_te2 * (i + 1))
            np.random.seed(seed=819 * (k + 1) + 2*kk + i + n)
            s2_te2[cluster_rows, :] = np.random.multivariate_normal(mu_mx_s[i], sigma_mx_2[i], n_te2)
        S = MatConvert(np.concatenate((s1_te2, s2_te2), axis=0), device, dtype)

        # Run two sample test (deep kernel) on generated data
        h_u, threshold_u, mmd_value_u = TST_MMD_u(
            model_u(S), N_per, N1_te2, S, sigma, sigma0_u, ep, alpha, device, dtype
        )
        h_u_init, threshold_u_init, mmd_value_u_init = TST_MMD_u(
            model_u_init(S), N_per, N1_te2, S, sigma_init, sigma0_u_init, ep_init, alpha, device, dtype
        )

        # Gather results
        count_u = count_u + h_u
        count_u_init = count_u_init + h_u_init
        print("Meta_KL:", count_u, "MMD-DK:", count_u_init)
        H_u[k], T_u[k], M_u[k] = h_u, threshold_u, mmd_value_u
        H_u_init[k], T_u_init[k], M_u_init[k] = h_u_init, threshold_u_init, mmd_value_u_init

    # Print test power of MetaKL and MMD-D
    power_meta = H_u.sum() / N_f
    print("Test Power of Meta MMD: ", power_meta)
    Results[i_test, 0, kk] = power_meta
    print("Test Power of Meta MMD (K times): ", Results[i_test, 0])
    print("Average Test Power of Meta MMD: ", Results[i_test, 0].sum() / (kk + 1))

    power_deep = H_u_init.sum() / N_f
    print("Test Power of deep MMD: ", power_deep)
    Results[i_test, 1, kk] = power_deep
    print("Test Power of deep MMD (K times): ", Results[i_test, 1])
    print("Average Test Power of deep MMD: ", Results[i_test, 1].sum() / (kk + 1))

print(Results[:,:,kk])

Meta_KL: 0 MMD-DK: 0
Meta_KL: 0 MMD-DK: 1
Meta_KL: 0 MMD-DK: 1
Meta_KL: 1 MMD-DK: 1
Meta_KL: 1 MMD-DK: 1
Meta_KL: 1 MMD-DK: 1
Meta_KL: 1 MMD-DK: 1
Meta_KL: 1 MMD-DK: 1
Meta_KL: 1 MMD-DK: 1
Meta_KL: 1 MMD-DK: 1
Meta_KL: 1 MMD-DK: 1
Meta_KL: 2 MMD-DK: 2
Meta_KL: 2 MMD-DK: 2
Meta_KL: 3 MMD-DK: 2
Meta_KL: 3 MMD-DK: 2
Meta_KL: 3 MMD-DK: 2
Meta_KL: 3 MMD-DK: 2
Meta_KL: 3 MMD-DK: 2
Meta_KL: 4 MMD-DK: 3
Meta_KL: 5 MMD-DK: 3
Meta_KL: 5 MMD-DK: 3
Meta_KL: 6 MMD-DK: 3
Meta_KL: 6 MMD-DK: 3
Meta_KL: 6 MMD-DK: 3
Meta_KL: 6 MMD-DK: 3
Meta_KL: 7 MMD-DK: 4
Meta_KL: 7 MMD-DK: 4
Meta_KL: 7 MMD-DK: 5
Meta_KL: 7 MMD-DK: 5
Meta_KL: 7 MMD-DK: 5
Meta_KL: 7 MMD-DK: 6
Meta_KL: 7 MMD-DK: 6
Meta_KL: 7 MMD-DK: 6
Meta_KL: 7 MMD-DK: 6
Meta_KL: 7 MMD-DK: 6
Meta_KL: 7 MMD-DK: 6
Meta_KL: 7 MMD-DK: 6
Meta_KL: 7 MMD-DK: 6
Meta_KL: 7 MMD-DK: 6
Meta_KL: 8 MMD-DK: 6
Meta_KL: 8 MMD-DK: 6
Meta_KL: 8 MMD-DK: 6
Meta_KL: 8 MMD-DK: 7
Meta_KL: 8 MMD-DK: 7
Meta_KL: 8 MMD-DK: 7
Meta_KL: 8 MMD-DK: 7
Meta_KL: 8 MMD-DK: 7
Meta_KL: 8 MM

In [15]:
def get_kernels(maml, epoch, lr, data, device, dtype):
    """Train one deep kernel per meta-task with ``DkTST``.

    Tasks are read from the ``data`` argument. Earlier revisions indexed the
    notebook global ``data_org`` and used ``data`` only for its length.

    Args:
        maml: object passed to ``DkTST`` as the source of initial parameters.
        epoch: training steps per task.
        lr: learning rate per task.
        data: array whose first axis indexes tasks; ``data[i]`` is the stacked
            sample matrix of task ``i``.
        device, dtype: placement of each task tensor.

    Returns:
        ``(models, sigmas, sigma0s, epsilons)``: four lists with one entry per
        task, in task order.

    NOTE(review): ``DkTST`` obtains its model from ``maml.get_init()`` on every
    call. Whether that yields a fresh network each time is not visible here;
    if it returns the same object, all entries of ``models`` alias one network.
    """
    num_meta_tasks = data.shape[0]
    models = []
    sigmas = []
    sigma0s = []
    epsilons = []
    for i in range(num_meta_tasks):
        S_task = torch.Tensor(data[i]).to(device, dtype)
        model_t, sigmaOPT_init_t, sigma0OPT_init_t, epsilonOPT_init_t = DkTST(
            maml, epoch, lr, S_task, device, dtype
        )
        models.append(model_t)
        sigmas.append(sigmaOPT_init_t)
        sigma0s.append(sigma0OPT_init_t)
        epsilons.append(epsilonOPT_init_t)
    return models, sigmas, sigma0s, epsilons

In [16]:
# Train one deep kernel per meta-task (num_meta_tasks runs of `epoch` steps each).
models, sigmas, sigma0s, epsilons = get_kernels(maml_init, epoch, learning_rate, data_org, device, dtype)

mmd_value_init:  6.293340993579477e-05 mmd_std_init:  0.00913394894450903 Statistic_init:  0.006890054792165756
mmd_value_init:  0.0003567013191059232 mmd_std_init:  0.008242054842412472 Statistic_init:  0.043278202414512634
mmd_value_init:  0.0005091156926937401 mmd_std_init:  0.007482580374926329 Statistic_init:  0.06804012507200241
mmd_value_init:  0.0006213310407474637 mmd_std_init:  0.0049499766901135445 Statistic_init:  0.12552201747894287
mmd_value_init:  0.0007816457655280828 mmd_std_init:  0.004023012239485979 Statistic_init:  0.19429366290569305
mmd_value_init:  0.0006359061226248741 mmd_std_init:  0.0027990397065877914 Statistic_init:  0.22718724608421326
mmd_value_init:  0.0005541521240957081 mmd_std_init:  0.002309071132913232 Statistic_init:  0.23998919129371643
mmd_value_init:  0.0005432469770312309 mmd_std_init:  0.0022297652903944254 Statistic_init:  0.2436341494321823
mmd_value_init:  0.0005570618668571115 mmd_std_init:  0.002275962382555008 Statistic_init:  0.2447588

# Single multi-kernel test on the most recent test set `S` (and its matching
# `N1_te2`) left behind by the evaluation loop, using random 0/1 kernel weights.
weights = np.random.randint(2, size=num_meta_tasks)
weights = torch.Tensor(weights).to(device, dtype)
# `Feas` was not defined anywhere at notebook scope; build the per-kernel
# features here the same way estimate_power does.
Feas = [model_t(S) for model_t in models]
h_t, threshold_t, mmd_value_t = TST_MMD_Multi(weights, Feas, N_per, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype)
# TST_MMD_u(Fea, N_per, N1, Fea_org, sigma, sigma0, epsilon, alpha, device, dtype, gamma=2, is_smooth=True)
# TST_MMD_u(model_u_init(S), N_per, N1_te2, S, sigma_init, sigma0_u_init, ep_init, alpha, device, dtype)

In [17]:
def estimate_power(weights, N, N_per, n_te2, models, sigmas, sigma0s, epsilons, device, dtype):
    """Estimate the rejection rate of the weighted multi-kernel test.

    ``N`` target-task test sets with ``n_te2`` points per cluster are
    generated with fixed seeds, every model in ``models`` embeds each set, and
    ``TST_MMD_Multi`` decides whether to reject.

    Changes from earlier revisions, which only worked right after the
    evaluation loop had run with the same ``n_te2``:
      * the sample buffers are allocated here instead of reusing the globals
        ``s1_te2`` / ``s2_te2``;
      * the per-sample size passed to the test is derived from ``n_te2``
        instead of the global ``N1_te2``;
      * features are computed for every entry of ``models`` instead of the
        first ``num_meta_tasks`` (a global) entries.

    Still reads the notebook-level ``Num_clusters``, ``mu_mx``, ``mu_mx_s``,
    ``sigma_mx_1``, ``sigma_mx_2`` and ``alpha``.

    Returns:
        Fraction of the ``N`` test sets on which the test rejected.
    """
    kk = 0
    n = 50  # seed offset; matches the `n` used by the evaluation loop
    N_te2 = Num_clusters * n_te2  # rows per sample
    dim = mu_mx.shape[1]
    s1_te2 = np.zeros([N_te2, dim])
    s2_te2 = np.zeros([N_te2, dim])
    H_u = np.zeros(N)
    np.random.seed(1102)
    for k in range(N):
        # Generate target tasks
        for i in range(Num_clusters):
            np.random.seed(seed=1102 * (k+2) + 2*kk + i + n)
            s1_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        for i in range(Num_clusters):
            np.random.seed(seed=819 * (k + 1) + 2*kk + i + n)
            s2_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx_s[i], sigma_mx_2[i], n_te2)
        S = np.concatenate((s1_te2, s2_te2), axis=0)
        S = MatConvert(S, device, dtype)
        # One feature matrix per kernel.
        Feas = [model_t(S) for model_t in models]
        # Run two sample test (deep kernel) on generated data
        h_u, threshold_u, mmd_value_u = TST_MMD_Multi(weights, Feas, N_per, N_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype)

        # Gather results
        H_u[k] = h_u

    # Print the estimated test power
    Result = H_u.sum() / (N * 1.0)
    print("Test Power:", Result)
    return Result

In [18]:
# Random 0/1 weight per kernel, then estimate the power of the weighted test.
weights = torch.Tensor(np.random.randint(2, size=num_meta_tasks)).to(device, dtype)
estimate_power(weights, N, N_per, n_te2, models, sigmas, sigma0s, epsilons, device, dtype)

Test Power: 0.61


0.61

In [None]:
class MMDLoss:
    """Work-in-progress loss over a set of pre-trained kernels.

    ``__init__`` stores the kernels and a weight vector of ones, one entry per
    model. ``masking_loss`` is not runnable yet; see the note on that method.
    """

    def __init__(self, N, N_per, n_te, models, sigmas, sigma0s, epsilons):
        """Store the test configuration and the per-kernel parameters.

        Args:
            N, N_per, n_te: stored unchanged on the instance.
            models, sigmas, sigma0s, epsilons: per-kernel sequences, stored
                unchanged; ``len(models)`` sets the number of weights.
        """
        # Was `model.shape[0]`: `model` is undefined, and `models` is a list.
        num_kernels = len(models)
        self.weights = np.ones(num_kernels)
        self.N = N
        self.N_per = N_per
        # Was `net`, an undefined name.
        self.n_te = n_te
        self.models = models
        self.sigmas = sigmas
        self.sigma0s = sigma0s
        self.epsilons = epsilons

        # Single-example wrappers: add a leading batch axis, call the batched
        # function, then strip the axis again.
        # NOTE(review): self.fnet_sum / self.fnet_pred are never assigned in
        # this class, so calling these wrappers fails until they are provided.
        def fnet_sum_single(params, x):
            return self.fnet_sum(params, x.unsqueeze(0)).squeeze(0)

        def fnet_pred_single(params, x):
            return self.fnet_pred(params, x.unsqueeze(0)).squeeze(0)

        self.fnet_sum_single = fnet_sum_single
        self.fnet_pred_single = fnet_pred_single

    def masking_loss(self, x_train):
        """Sum of a prediction term, a kernel term and a mask-sparsity term.

        NOTE(review): this method depends on names that are neither imported
        nor assigned anywhere visible in this notebook (`vmap`, `pt`,
        `empirical_ntk_batch_cross`, and the attributes `net_params`,
        `masked_params`, `masks`, `gamma_1`, `gamma_2`). It cannot run as-is
        and appears to be carried over from another project.
        """
        full_pred = vmap(self.fnet_pred_single, (None, 0))(self.net_params, x_train)
        masked_pred = vmap(self.fnet_pred_single, (None, 0))(self.masked_params, x_train)

        full_ntk = empirical_ntk_batch_cross(self.fnet_sum_single, self.net_params, x_train)
        masked_ntk = empirical_ntk_batch_cross(self.fnet_sum_single, self.masked_params, x_train)

        batch_size = x_train.shape[0]
        pred_term = (1. / batch_size) * pt.sum((full_pred - masked_pred) ** 2)
        ntk_term = (self.gamma_1 / batch_size) ** 2 * pt.sum((full_ntk - masked_ntk) ** 2)
        mask_term = self.gamma_2 * pt.sum(pt.stack([pt.sum(pt.abs(m)) for m in self.masks]))

        loss = pred_term + ntk_term + mask_term
        return loss