In [1]:
import  torch
import  numpy as np
from dataTST import SynDataMetaTST
from MetaTST import Meta
import argparse
parser = argparse.ArgumentParser()
from utils import MatConvert, TST_MMD_u, MMDu, TST_MMD_Multi, TST_Wald, TST_Ost, MMDg, get_Analytic_Weights, Analytic_Weights
from generate_data import *
from learner import Learner

import matplotlib.pyplot as plt
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

import pickle
import os
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
def mem_ratio():
    """Print the ratio of allocated to reserved CUDA memory.

    The ratio is ``torch.cuda.memory_allocated() / torch.cuda.memory_reserved()``.
    When nothing has been reserved yet (fresh kernel, or no CUDA work done so
    far) the denominator is 0; in that case ``0.0`` is printed instead of
    raising ``ZeroDivisionError``.
    """
    reserved = torch.cuda.memory_reserved()
    if reserved == 0:
        # Nothing reserved yet, so there is nothing allocated either.
        print(0.0)
        return
    print(torch.cuda.memory_allocated() / reserved)

In [3]:
class myargs:
    """Plain container for the experiment hyper-parameters.

    Stands in for an ``argparse`` namespace: every setting is an instance
    attribute so it can be read as ``args.<name>``.
    """

    def __init__(self):
        self.n=50                  # read into `n` ("number of samples in per mode") by the setup cell
        self.n_te=150              # read into `n_te` (training samples for the target task)
        self.d=2                   # data dimension
        self.K=10                  # number of trials (size of the last axis of `Results`)
        self.num_meta_tasks=100    # number of meta-samples / meta tasks
        self.epoch=1000            # number of meta-training steps passed to `alg_one`
        self.meta_lr=1e-2          # presumably the outer-loop learning rate used by Meta -- TODO confirm
        self.update_lr=0.8         # presumably the inner-loop learning rate used by Meta -- TODO confirm
        self.update_step=10        # presumably the number of inner-loop updates used by Meta -- TODO confirm
        self.closeness=0.3         # shifts the correlation of the meta tasks (see meta-sample generation)

    def __repr__(self):
        # Show the settings instead of the default "<__main__.myargs object at 0x...>".
        fields = ', '.join('{}={!r}'.format(key, value) for key, value in vars(self).items())
        return 'myargs({})'.format(fields)
args = myargs()

In [4]:
def mu_sigma(delta, dim=None):
    """Build the means and covariances of the two-cluster Gaussian mixture.

    Args:
        delta: off-diagonal covariance entry at positions (0, 1)/(1, 0).
            Cluster 0 gets ``+delta`` and cluster 1 gets ``-delta``.
        dim: data dimension. When omitted, the notebook-level ``d`` is used
            (the original behaviour), which requires the setup cell to have
            been executed first.

    Returns:
        ``(mu, sigma)`` where ``mu`` is a ``(2, dim)`` array (cluster 0 at the
        origin, cluster 1 at 0.5 in every coordinate) and ``sigma`` is a list
        of two ``(dim, dim)`` covariance matrices.
    """
    if dim is None:
        dim = d  # fall back to the notebook-level data dimension
    num_clusters = 2
    mu = np.zeros([num_clusters, dim])
    mu[1] = mu[1] + 0.5
    sigma = [np.identity(dim), np.identity(dim)]
    # Opposite-signed correlation between the first two coordinates.
    sigma[0][0, 1] = delta
    sigma[0][1, 0] = delta
    sigma[1][0, 1] = -delta
    sigma[1][1, 0] = -delta
    return mu, sigma

In [5]:
def gen_data(n, delta = 0.7, kk=0):
    """Draw one two-sample problem from the two-cluster Gaussian mixtures.

    The first sample uses ``mu_sigma(0)`` and the second ``mu_sigma(delta)``.
    The global NumPy RNG is re-seeded before every cluster draw, so the output
    is a deterministic function of ``(n, delta, kk)``. Uses the notebook-level
    ``Num_clusters`` and ``d``.

    Args:
        n: number of points drawn per cluster and per sample.
        delta: correlation parameter forwarded to ``mu_sigma`` for the second sample.
        kk: index that shifts the seeds.

    Returns:
        ``(S, s1, s2)`` where ``S`` stacks ``s1`` on top of ``s2``.
    """
    means_p, covs_p = mu_sigma(0)
    means_q, covs_q = mu_sigma(delta)
    rows = n * Num_clusters
    s1 = np.zeros([rows, d])
    s2 = np.zeros([rows, d])
    # First sample is filled completely before the second one, cluster by
    # cluster, so the RNG is seeded and consumed in the same order as before.
    plans = (
        (s1, means_p, covs_p, 1102*kk + n),
        (s2, means_q, covs_q, 819*kk + 1 + n),
    )
    for target, means, covs, seed_base in plans:
        for c in range(Num_clusters):
            np.random.seed(seed=seed_base + c)
            target[n * c:n * (c + 1), :] = np.random.multivariate_normal(means[c], covs[c], n)
    S = np.concatenate((s1, s2), axis=0)
    return S, s1, s2

In [6]:
def gen_meta_data(num_meta_tasks, n, closeness):
    """Generate the stacked samples of all meta tasks.

    Task ``t`` (0-based) is produced by ``gen_data`` with correlation
    ``0.6 - closeness + (0.1/num_meta_tasks) * (t+1)`` and seed index ``t``.

    Returns:
        Array of shape ``(num_meta_tasks, 4*n, 2)``; row ``t`` holds the
        concatenated sample ``S`` of task ``t``.
    """
    # The buffer starts from random normals, but every row is overwritten below.
    tasks = np.random.randn(num_meta_tasks,4*n,2)
    for task_idx in range(num_meta_tasks):
        corr = 0.6 - closeness + (0.1/num_meta_tasks) * (task_idx+1)
        stacked, _, _ = gen_data(n, corr, task_idx)
        tasks[task_idx] = stacked

    return tasks

In [7]:
def alg_one(maml, db_train, epoch):
    """Run ``epoch`` meta-training steps and return the last kernel produced.

    Each step pulls one batch from ``db_train.next()``, moves the four arrays
    to the notebook-level ``device`` and calls ``maml`` on them. The objective
    is printed every 10th step.

    Returns:
        ``(model_u, sigma, sigma0_u, ep)`` as returned by the final ``maml`` call.
    """
    for it in range(epoch):
        spt_x, spt_y, qry_x, qry_y = db_train.next()
        batch = [torch.from_numpy(arr).to(device) for arr in (spt_x, spt_y, qry_x, qry_y)]
        # One meta-update of the kernel.
        J_value, model_u, sigma, sigma0_u, ep = maml(*batch)
        # Progress report.
        if it % 10 == 0:
            print('step:', it, '\ttraining J value:', J_value.item())
    return model_u, sigma, sigma0_u, ep

In [8]:
def get_no_trainable_tensors(model):
    """Count, print and return the number of trainable scalar parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted; each contributes
    the product of its shape.
    """
    total = sum(np.prod(param.shape) for param in model.parameters() if param.requires_grad)
    print('Total trainable tensors:', total)
    return total

In [9]:
def fine_tune(model, it, learning_rate, S, sigma, sigma0, ep, device, dtype):
    """Fine-tune a kernel (network + three scalar parameters) for ``it`` Adam steps.

    Each step draws a random mini-batch, evaluates ``MMDu`` on the network
    output and minimises ``-(TEMP[0] + 1e-8) / sqrt(TEMP[1] + 1e-8)``.

    NOTE(review): ``S`` is used only for its row count. The mini-batches are
    drawn from the notebook-level arrays ``s1_te`` and ``s2_te``, and the
    notebook-level ``Num_clusters`` is read as well, so the function only works
    after the main cell has populated those globals -- confirm this is intended.

    Args:
        model: network whose parameters are optimised in place.
        it: number of optimisation steps.
        learning_rate: Adam learning rate.
        S: tensor/array whose first dimension is 4 * n_te (only the shape is read).
        sigma, sigma0, ep: tensors added to the optimiser; their squares are
            what is passed to ``MMDu``.
        device, dtype: forwarded to ``MatConvert`` for the mini-batch.

    Returns:
        ``(model, sigma, sigma0, ep)`` -- the same objects, updated in place.
    """
    n_te = int(S.shape[0]/4)
    # setup optimizer for training deep kernel
    optimizer = torch.optim.Adam(list(model.parameters()) + [ep] + [sigma] + [sigma0],
                                   lr=learning_rate)
    # Per-step objective values; filled below but not returned.
    J_star = np.zeros([it])
    for t in range(it):
        # one way to train kernel with limited data
        # Mini-batch: a fifth of the rows of each sample, drawn without replacement
        # and independently for the two samples.
        n_random = int(n_te*Num_clusters/5)
        selected_cls1 = np.random.choice(n_te * Num_clusters, n_random, False)
        selected_cls2 = np.random.choice(n_te * Num_clusters, n_random, False)
        s1_te_random = s1_te[selected_cls1,:]
        s2_te_random = s2_te[selected_cls2, :]
        S_random = np.concatenate((s1_te_random, s2_te_random), axis=0)
        S_random = MatConvert(S_random, device, dtype)
        # another way to train kernel with limited data (similar performance)
        # S_random = S
        # n_random = N1

        # Compute epsilon, sigma and sigma_0
        # Squaring keeps the values handed to MMDu non-negative.
        ep_ = ep ** 2
        sigma_ = sigma ** 2
        sigma0_ = sigma0 ** 2

        # Compute output of the deep network
        model_output = model(S_random)

        # Compute J (STAT_u)
        # Per the print labels below, TEMP[0] is treated as the MMD value and
        # sqrt(TEMP[1]) as its std; 1e-8 guards the division.
        TEMP = MMDu(model_output, n_random, S_random, sigma_, sigma0_, ep_)
        mmd_value_temp = -1 * (TEMP[0] + 10 ** (-8))
        mmd_std_temp = torch.sqrt(TEMP[1] + 10 ** (-8))
        if mmd_std_temp.item() == 0:
            print('error!!')
        if np.isnan(mmd_std_temp.item()):
            print('error!!')
        # Negated ratio: minimising STAT maximises value / std.
        STAT = torch.div(mmd_value_temp, mmd_std_temp)
        J_star[t] = STAT.item()

        # Initialize optimizer and Compute gradient
        optimizer.zero_grad()
        STAT.backward(retain_graph=True)

        # Update weights using gradient descent
        optimizer.step()
        # Print MMD, std of MMD and J
        if t % 100 == 0:
            print("mmd_value: ", -1 * mmd_value_temp.item(), "mmd_std: ", mmd_std_temp.item(), "Statistic: ",
                  -1 * STAT.item())
    return model, sigma, sigma0, ep

In [10]:
def DkTST(maml_init, epoch, learning_rate, S, device, dtype):
    """Train a deep kernel on ``S`` starting from the initial kernel of ``maml_init``.

    The network and the three scalar parameters come from
    ``maml_init.get_init()``; the scalars are re-created as trainable tensors
    holding the square root of the initial values and squared again before
    use. Training runs ``epoch`` full-batch Adam steps on
    ``-(TEMP[0] + 1e-8) / sqrt(TEMP[1] + 1e-8)`` with ``TEMP = MMDu(...)``.

    The previous version stored the objective in a buffer sized by the
    notebook-level ``N_epoch`` and raised ``IndexError`` whenever
    ``epoch > N_epoch``; the buffer was never read, so it has been removed.

    Still reads the notebook-level ``n`` and ``kk`` for seeding.

    Args:
        maml_init: object exposing ``get_init()``.
        epoch: number of optimisation steps.
        learning_rate: Adam learning rate.
        S: stacked two-sample data; the first half of the rows is the first sample.
        device, dtype: forwarded to ``MatConvert``.

    Returns:
        ``(model, sigma, sigma0, ep)`` where the scalars are the trained
        (un-squared) parameters.
    """
    _, model, sigma_init, sigma0_u_init, ep_init = maml_init.get_init()
    torch.manual_seed(1 * 19 + n)
    torch.cuda.manual_seed(kk * 19 + n)
    # Optimise the square roots so that the squared values stay non-negative.
    ep = MatConvert(np.ones(1) * np.sqrt(ep_init.detach().cpu().numpy()), device, dtype)
    ep.requires_grad = True
    sigma = MatConvert(np.ones(1) * np.sqrt(sigma_init.detach().cpu().numpy()), device, dtype)
    sigma.requires_grad = True
    sigma0 = MatConvert(np.ones(1) * np.sqrt(sigma0_u_init.detach().cpu().numpy()), device, dtype)
    sigma0.requires_grad = True
    # Setup optimizer for training init kernel
    optimizer = torch.optim.Adam(list(model.parameters()) + [ep] + [sigma] + [sigma0],
                               lr=learning_rate)
    N1 = int(S.shape[0]/2)
    for t in range(epoch):
        # Compute epsilon, sigma and sigma_0
        ep_ = ep ** 2
        sigma_ = sigma ** 2
        sigma0_ = sigma0 ** 2
        # Compute output of the deep network
        model_output = model(S)
        # Compute J (STAT_u)
        TEMP = MMDu(model_output, N1, S, sigma_, sigma0_, ep_)
        mmd_value_temp = -1 * (TEMP[0] + 10 ** (-8))
        mmd_std_temp = torch.sqrt(TEMP[1] + 10 ** (-8))
        if mmd_std_temp.item() == 0:
            print('error!!')
        if np.isnan(mmd_std_temp.item()):
            print('error!!')
        # Negated ratio: minimising STAT maximises value / std.
        STAT = torch.div(mmd_value_temp, mmd_std_temp)
        # Initialize optimizer and Compute gradient
        optimizer.zero_grad()
        STAT.backward(retain_graph=True)
        # Update weights using gradient descent
        optimizer.step()
        # Print MMD, std of MMD and J
        if t % 100 == 0:
            print("mmd_value_init: ", -1 * mmd_value_temp.item(), "mmd_std_init: ", mmd_std_temp.item(), "Statistic_init: ",
                  -1 * STAT.item())
    return model, sigma, sigma0, ep

In [11]:
def get_kernels_bimodal(maml, epoch, lr, data, device, dtype):
    """Train one deep kernel per meta task with ``DkTST``.

    The previous version read the notebook-level ``data_org`` inside the loop
    and ignored the ``data`` argument except for its length; the argument is
    now the only data source.

    Args:
        maml: object passed to ``DkTST`` as the kernel initialiser.
        epoch: number of training steps per task.
        lr: learning rate per task.
        data: array indexed by task; ``data[i]`` is the stacked sample of task ``i``.
        device, dtype: where/how the data tensors are created.

    Returns:
        ``(models, sigmas, sigma0s, epsilons)`` -- four lists with one entry
        per task; the models are switched to eval mode and the scalar
        parameters are detached.
    """
    num_meta_tasks = data.shape[0]
    models = []
    sigmas = []
    sigma0s = []
    epsilons = []
    for i in range(num_meta_tasks):
        task_data = torch.Tensor(data[i]).to(device, dtype)
        model_t, sigma_t, sigma0_t, epsilon_t = DkTST(maml, epoch, lr, task_data, device, dtype)
        models.append(model_t)
        sigmas.append(sigma_t.detach())
        sigma0s.append(sigma0_t.detach())
        epsilons.append(epsilon_t.detach())

    for model in models:
        model.eval()

    return models, sigmas, sigma0s, epsilons

In [12]:
def save_kernels(models, sigmas, sigma0s, epsilons):
    """Write every kernel to ``./kernels/Bimodal_Kernels/bimodal_<i>.pt``.

    Each file holds a dict with integer keys: ``0`` -> model ``state_dict``,
    ``1`` -> sigma, ``2`` -> sigma0, ``3`` -> epsilon. The number of kernels
    is taken from ``len(sigmas)``.

    The target directory is created when missing; previously ``torch.save``
    failed on a fresh checkout.
    """
    os.makedirs('./kernels/Bimodal_Kernels', exist_ok=True)
    num_meta_tasks = len(sigmas)
    for i in range(num_meta_tasks):
        state = {0: models[i].state_dict(),
            1: sigmas[i],
            2: sigma0s[i],
            3: epsilons[i]}
        torch.save(state, './kernels/Bimodal_Kernels/bimodal_' + str(i) + '.pt')

In [13]:
def load_kernels_bimodal(num_meta_tasks,config):
    """Load the kernels stored as ``./kernels/Bimodal_Kernels/bimodal_<i>.pt``.

    For every index a ``Learner(config)`` is created on CUDA, its weights are
    restored from entry ``0`` of the saved dict and it is converted to double
    precision. Entries ``1``, ``2`` and ``3`` are converted to double and
    detached.

    Returns:
        ``(models, sigmas, sigma0s, epsilons)`` -- four lists of length
        ``num_meta_tasks``.
    """
    models, sigmas, sigma0s, epsilons = [], [], [], []
    for task_id in range(num_meta_tasks):
        checkpoint = torch.load('./kernels/Bimodal_Kernels/bimodal_' + str(task_id) + '.pt')
        net = Learner(config).cuda()
        net.load_state_dict(checkpoint[0])
        net_double = net.double()
        sigma = checkpoint[1].double()
        sigma0 = checkpoint[2].double()
        eps = checkpoint[3].double()
        models.append(net_double)
        sigmas.append(sigma.detach())
        sigma0s.append(sigma0.detach())
        epsilons.append(eps.detach())
    return models, sigmas, sigma0s, epsilons

In [14]:
def load_kernels_cifar(num_meta_tasks,config):
    """Load the kernels stored as ``./kernels/CIFAR10_Kernels/mkl_kernel_CIFAR10_<i>_64.pt``.

    For every index a ``Learner(config)`` is created on CUDA, its weights are
    restored from entry ``0`` of the saved dict and it is converted to double
    precision. Entries ``1``, ``2`` and ``3`` are converted to double and
    detached.

    Returns:
        ``(models, sigmas, sigma0s, epsilons)`` -- four lists of length
        ``num_meta_tasks``.
    """
    models, sigmas, sigma0s, epsilons = [], [], [], []
    for task_id in range(num_meta_tasks):
        checkpoint = torch.load('./kernels/CIFAR10_Kernels/mkl_kernel_CIFAR10_{}_64.pt'.format(task_id))
        net = Learner(config).cuda()
        net.load_state_dict(checkpoint[0])
        net_double = net.double()
        sigma = checkpoint[1].double()
        sigma0 = checkpoint[2].double()
        eps = checkpoint[3].double()
        models.append(net_double)
        sigmas.append(sigma.detach())
        sigma0s.append(sigma0.detach())
        epsilons.append(eps.detach())
    return models, sigmas, sigma0s, epsilons

In [15]:
# ---- Experiment configuration -------------------------------------------------
# Everything below is notebook-level state; several functions defined in
# earlier cells (mu_sigma, gen_data, fine_tune, DkTST, ...) read these names
# as globals, so this part must run before any of them is called.

# Seed every RNG used by the notebook.
torch.manual_seed(222)
torch.cuda.manual_seed_all(222)
np.random.seed(222)

dtype = torch.double
device = torch.device("cuda:0")

d = args.d  # dimension of data
n = args.n  # number of samples in per mode
n_te = args.n_te # number of training samples for the target task
K = args.K  # number of trials
num_meta_tasks = args.num_meta_tasks # number of meta-samples
print('n: ' + str(n) + ' d: ' + str(d))

N_per = 100  # permutation times
alpha = 0.05  # test threshold
x_in = d  # number of neurons in the input layer, i.e., dimension of data
H = 30  # number of neurons in the hidden layer
x_out = 3 * d  # number of neurons in the output layer
learning_rate = 0.00005 # learning rate for MMD-D
N_epoch = 1000  # maximum number of epochs for training
N = 100  # number of test sets
N_f = 100.0  # number of test sets (float)
list_nte = [50, 80, 100, 120, 150, 200, 250] # number of test samples for the target task

# Layer specification handed to Meta(...) and Learner(...): four linear layers
# (x_in -> H -> H -> H -> x_out) with softplus in between.
config = [
    ('linear', [H, x_in]),
    ('softplus', [True]),
    ('linear', [H, H]),
    ('softplus', [True]),
    ('linear', [H, H]),
    ('softplus', [True]),
    ('linear', [x_out, H]),
]

print(args)

# Generate variance and co-variance matrix of Q (target task)
# Two clusters: cluster 0 at the origin, cluster 1 shifted by 0.5 in every
# coordinate. sigma_mx_1 (identity) is used for the first sample; sigma_mx_2
# (correlation +0.7 / -0.7 per cluster) is used for the second sample.
Num_clusters = 2
mu_mx = np.zeros([Num_clusters, d])
mu_mx[1] = mu_mx[1] + 0.5
sigma_mx_1 = np.identity(d)
sigma_mx_2 = [np.identity(d), np.identity(d)]
sigma_mx_2[0][0, 1] = 0.7
sigma_mx_2[0][1, 0] = 0.7
sigma_mx_2[1][0, 1] = -0.7
sigma_mx_2[1][1, 0] = -0.7

# Same construction with an "_s" suffix; the off-diagonal entries of
# sigma_mx_2_s are filled in per meta task further down.
mu_mx_s = np.zeros([Num_clusters, d])
mu_mx_s[1] = mu_mx_s[1] + 0.5
sigma_mx_1_s = np.identity(d)
sigma_mx_2_s = [np.identity(d), np.identity(d)]

# Naming variables
# Pre-allocated buffers. Results is indexed as
# [test sample size, method (0 = meta kernel, 1 = deep kernel from init), trial].
s1 = np.zeros([n * Num_clusters, d])
s2 = np.zeros([n * Num_clusters, d])
s1_te = np.zeros([n_te * Num_clusters, d])
s2_te = np.zeros([n_te * Num_clusters, d])
J_star_u = np.zeros([N_epoch])
J_star_u_init = np.zeros([N_epoch])
Results = np.zeros([len(list_nte), 2, K])

<torch._C.Generator at 0x7f58fbf5e750>

n: 50 d: 2
<__main__.myargs object at 0x7f59202a2d00>


# ---- Main experiment: meta-train, fine-tune, train a baseline, evaluate -------
# kk is the trial index; only trial 0 is run here.
kk=0
# Two separate Meta instances: `maml` is meta-trained below, `maml_init`
# only supplies the initial kernel for the baseline.
maml = Meta(args, config).to(device)
maml_init = Meta(args, config).to(device)

get_no_trainable_tensors(maml)
get_no_trainable_tensors(maml_init)

# generate meta-samples
# Task nn uses correlation 0.6 - closeness + (0.1/num_meta_tasks) * (nn+1) for
# the second sample (sign flipped for cluster 1); the first sample uses the
# identity covariance. Seeds are fixed per task and cluster.
data_org = np.random.randn(num_meta_tasks,4*n,2)
for nn in range(num_meta_tasks):
    sigma_mx_2_s[0][0, 1] = 0.6 - args.closeness + (0.1/num_meta_tasks) * (nn+1)
    sigma_mx_2_s[0][1, 0] = 0.6 - args.closeness + (0.1/num_meta_tasks) * (nn+1)
    sigma_mx_2_s[1][0, 1] = -(0.6 - args.closeness + (0.1/num_meta_tasks) * (nn+1))
    sigma_mx_2_s[1][1, 0] = -(0.6 - args.closeness + (0.1/num_meta_tasks) * (nn+1))

    for i in range(Num_clusters):
        np.random.seed(seed=1102*nn + i + n)
        s1[n * (i):n * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1_s, n)
    for i in range(Num_clusters):
        np.random.seed(seed=819*nn + 1 + i + n)
        s2[n * (i):n * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_2_s[i], n)
    data_org[nn] = np.concatenate((s1, s2), axis=0)
# get training loader for meta-samples
# NOTE(review): the positional arguments 10, 2, 150, 50 are defined by
# SynDataMetaTST, which is not visible here -- confirm their meaning there.
db_train = SynDataMetaTST(data_org, 10, 2, 150, 50)

model_u, sigma, sigma0_u, ep = alg_one(maml, db_train, args.epoch)

# setup meta kernels
# Trainable copies holding the square roots of the meta-learned values;
# fine_tune squares them again before use.
torch.manual_seed(1 * 19 + n)
torch.cuda.manual_seed(kk * 19 + n)
epsilonOPT = MatConvert(np.ones(1) * np.sqrt(ep.detach().cpu().numpy()), device, dtype)
epsilonOPT.requires_grad = True
sigmaOPT = MatConvert(np.ones(1) * np.sqrt(sigma.detach().cpu().numpy()), device, dtype)
sigmaOPT.requires_grad = True
sigma0OPT = MatConvert(np.ones(1) * np.sqrt(sigma0_u.detach().cpu().numpy()), device, dtype)
sigma0OPT.requires_grad = True
print(epsilonOPT.item())

# Generate training data for target tasks
# S = gen_data(n_te)[0]
# fine_tune reads s1_te / s2_te as globals, so they must be filled before it is called.
for i in range(Num_clusters):
    np.random.seed(seed=1102*kk + i + n)
    s1_te[n_te * (i):n_te * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te)
for i in range(Num_clusters):
    np.random.seed(seed=819*kk + 1 + i + n)
    s2_te[n_te * (i):n_te * (i + 1), :] = np.random.multivariate_normal(mu_mx_s[i], sigma_mx_2[i], n_te)
S = np.concatenate((s1_te, s2_te), axis=0)
S = MatConvert(S, device, dtype)
N1 = Num_clusters*n_te

# Meta Kernels as init when training with training set from the target task
np.random.seed(seed=1102)
torch.manual_seed(1102)
torch.cuda.manual_seed(1102)
it=101
model_u, sigmaOPT, sigma0OPT, epsilonOPT = fine_tune(model_u, it, learning_rate/3, S, sigmaOPT,  sigma0OPT, epsilonOPT, device, dtype)

#page 6
# random kernel as init when training with training set from the target task --
# --> validate the consistency of performance of MMD-D
# setup init Kernels 
# print(epsilonOPT_init.item())
# epoch = 1000
epoch = 1000
# The scalars unpacked on the next line (sigma_init, sigma0_u_init, ep_init)
# are the ones passed to TST_MMD_u for the baseline in the test loop below.
_,model_u_init, sigma_init, sigma0_u_init, ep_init = maml_init.get_init()
model_u_init, sigmaOPT_init, sigma0OPT_init, epsilonOPT_init = DkTST(maml_init, epoch, learning_rate, S, device, dtype)

kk=0
# test the trained kernel on the target task (with different sample size: 50, 80, 100, 120, 150, 200, 250)
for i_test in range(len(list_nte)):
    n_te2 = list_nte[i_test]
    s1_te2 = np.zeros([n_te2 * Num_clusters, d])
    s2_te2 = np.zeros([n_te2 * Num_clusters, d])
    N1_te2 = Num_clusters * n_te2
    # Per-trial decision (H), threshold (T) and statistic (M) for both methods.
    H_u = np.zeros(N)
    T_u = np.zeros(N)
    M_u = np.zeros(N)
    H_u_init = np.zeros(N)
    T_u_init = np.zeros(N)
    M_u_init = np.zeros(N)
    np.random.seed(1102)
    count_u = 0
    count_u_init = 0
    for k in range(N):
        # Generate target tasks
        for i in range(Num_clusters):
            np.random.seed(seed=1102 * (k+2) + 2*kk + i + n)
            s1_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        for i in range(Num_clusters):
            np.random.seed(seed=819 * (k + 1) + 2*kk + i + n)
            s2_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx_s[i], sigma_mx_2[i], n_te2)
        S = np.concatenate((s1_te2, s2_te2), axis=0)
        S = MatConvert(S, device, dtype)

        # Run two sample test (deep kernel) on generated data
        # NOTE(review): the tests receive sigma / sigma0_u / ep (from alg_one) and
        # sigma_init / sigma0_u_init / ep_init (from get_init), not the values
        # returned by fine_tune / DkTST (sigmaOPT, ... and sigmaOPT_init, ...).
        # Unless those objects are updated in place elsewhere, the tuned scalar
        # parameters are never used here -- confirm this is intended.
        h_u, threshold_u, mmd_value_u = TST_MMD_u(model_u(S), N_per, N1_te2, S, sigma, sigma0_u, ep, alpha, device, dtype)
        h_u_init, threshold_u_init, mmd_value_u_init = TST_MMD_u(model_u_init(S), N_per, N1_te2, S, sigma_init, sigma0_u_init, ep_init, alpha, device, dtype)

        # Gather results
        count_u = count_u + h_u
        count_u_init = count_u_init + h_u_init
        print("Meta_KL:", count_u, "MMD-DK:", count_u_init)
        H_u[k] = h_u
        T_u[k] = threshold_u
        M_u[k] = mmd_value_u
        H_u_init[k] = h_u_init
        T_u_init[k] = threshold_u_init
        M_u_init[k] = mmd_value_u_init

    # Print test power of MetaKL and MMD-D
    # Power = rejection rate over the N trials, stored per sample size and method.
    print("Test Power of Meta MMD: ", H_u.sum() / N_f)
    Results[i_test, 0, kk] = H_u.sum() / N_f
    print("Test Power of Meta MMD (K times): ", Results[i_test, 0])
    print("Average Test Power of Meta MMD: ", Results[i_test, 0].sum() / (kk + 1))

    print("Test Power of deep MMD: ", H_u_init.sum() / N_f)
    Results[i_test, 1, kk] = H_u_init.sum() / N_f
    print("Test Power of deep MMD (K times): ", Results[i_test, 1])
    print("Average Test Power of deep MMD: ", Results[i_test, 1].sum() / (kk + 1))

print(Results[:,:,kk])

In [16]:
# To retrain and persist the per-task kernels instead of loading them, run:
# models, sigmas, sigma0s, epsilons = get_kernels_bimodal(maml_init, epoch, learning_rate, data_org, device, dtype)
# save_kernels(models, sigmas, sigma0s, epsilons)

models, sigmas, sigma0s, epsilons = load_kernels_bimodal(num_meta_tasks,config)

# Random 0/1 weight per kernel, only to exercise TST_MMD_Multi once.
weights = np.random.randint(2, size=num_meta_tasks)
weights = torch.Tensor(weights).to(device, dtype)

# `Feas` was never defined at notebook scope (NameError on a fresh kernel):
# build one feature map per loaded kernel for the current sample S, the same
# way the estimate_* helpers do. S and N1_te2 are the values left behind by
# the last iteration of the evaluation loop in the previous cell.
with torch.no_grad():
    Feas = [kernel_model(S).detach() for kernel_model in models]
h_t, threshold_t, mmd_value_t = TST_MMD_Multi(weights, Feas, N_per, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype)
# TST_MMD_u(Fea, N_per, N1, Fea_org, sigma, sigma0, epsilon, alpha, device, dtype, gamma=2, is_smooth=True)
# TST_MMD_u(model_u_init(S), N_per, N1_te2, S, sigma_init, sigma0_u_init, ep_init, alpha, device, dtype)

In [17]:
def estimate_power(weights, N, N_per, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Estimate the rejection rate of ``TST_MMD_Multi`` under the alternative.

    ``N`` two-sample problems are drawn from the synthetic target task (first
    sample with ``sigma_mx_1``, second with ``sigma_mx_2``), every kernel in
    ``models`` is evaluated on the stacked sample and the weighted test is run.

    Features are now computed for exactly the kernels in ``models`` (the old
    loop ran over the notebook-level ``num_meta_tasks`` and could disagree with
    the arguments) and without building an autograd graph. Buffers that were
    filled but never read have been removed.

    Reads the notebook-level ``d``, ``mu_mx``, ``mu_mx_s``, ``sigma_mx_1``,
    ``sigma_mx_2`` and ``alpha``.

    Args:
        weights: kernel weights forwarded to ``TST_MMD_Multi``.
        N: number of simulated trials.
        N_per: number of permutations forwarded to ``TST_MMD_Multi``.
        n_te2: points per cluster and per sample.
        models, sigmas, sigma0s, epsilons: the kernels, one entry per kernel.
        device, dtype: forwarded to ``MatConvert`` and the test.
        gamma: forwarded to the test.

    Returns:
        Fraction of the ``N`` trials in which the test rejected.
    """
    Num_clusters = 2
    kk = 0   # trial offset used in the seeds
    n = 50   # seed offset; fixed, independent of the notebook-level n
    s1_te2 = np.zeros([n_te2 * 2, d])
    s2_te2 = np.zeros([n_te2 * 2, d])
    H_u = np.zeros(N)
    np.random.seed(1102)
    N1_te2 = n_te2*2
    N_f = N*1.0
    for k in range(N):
        # Generate target tasks (seeds fixed per trial and cluster)
        for i in range(Num_clusters):
            np.random.seed(seed=1102 * (k+2) + 2*kk + i + n)
            s1_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        for i in range(Num_clusters):
            np.random.seed(seed=819 * (k + 1) + 2*kk + i + n)
            s2_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx_s[i], sigma_mx_2[i], n_te2)
        S = np.concatenate((s1_te2, s2_te2), axis=0)
        S = MatConvert(S, device, dtype)
        # One feature map per kernel; no gradients are needed for testing.
        with torch.no_grad():
            Feas = [model(S.detach()).detach() for model in models]
        # Run two sample test (deep kernel) on generated data
        h_u, _, _ = TST_MMD_Multi(weights, Feas, N_per, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype, gamma=gamma)
        H_u[k] = h_u

    # Rejection rate over all trials
    print("Test Power:", H_u.sum() / N_f)
    Result = H_u.sum() / N_f
    return Result

In [18]:
def estimate_power_wald(N, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Estimate the rejection rate of ``TST_Wald`` under the alternative.

    ``N`` two-sample problems are drawn from the synthetic target task (first
    sample with ``sigma_mx_1``, second with ``sigma_mx_2``), every kernel in
    ``models`` is evaluated on the stacked sample and the test is run.

    Features are now computed for exactly the kernels in ``models`` (the old
    loop ran over the notebook-level ``num_meta_tasks``) and without building
    an autograd graph. Unused buffers have been removed.

    Reads the notebook-level ``d``, ``mu_mx``, ``mu_mx_s``, ``sigma_mx_1``,
    ``sigma_mx_2`` and ``alpha``.

    Args:
        N: number of simulated trials.
        n_te2: points per cluster and per sample.
        models, sigmas, sigma0s, epsilons: the kernels, one entry per kernel.
        device, dtype: forwarded to ``MatConvert`` and the test.
        gamma: forwarded to the test.

    Returns:
        Fraction of the ``N`` trials in which the test rejected.
    """
    Num_clusters = 2
    kk = 0   # trial offset used in the seeds
    n = 50   # seed offset; fixed, independent of the notebook-level n
    s1_te2 = np.zeros([n_te2 * 2, d])
    s2_te2 = np.zeros([n_te2 * 2, d])
    H_u = np.zeros(N)
    np.random.seed(1102)
    N1_te2 = n_te2*2
    N_f = N*1.0
    for k in range(N):
        # Generate target tasks (seeds fixed per trial and cluster)
        for i in range(Num_clusters):
            np.random.seed(seed=1102 * (k+2) + 2*kk + i + n)
            s1_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        for i in range(Num_clusters):
            np.random.seed(seed=819 * (k + 1) + 2*kk + i + n)
            s2_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx_s[i], sigma_mx_2[i], n_te2)
        S = np.concatenate((s1_te2, s2_te2), axis=0)
        S = MatConvert(S, device, dtype)
        # One feature map per kernel; no gradients are needed for testing.
        with torch.no_grad():
            Feas = [model(S.detach()).detach() for model in models]
        # Run two sample test (deep kernel) on generated data
        h_u, _, _ = TST_Wald(Feas, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype, gamma=gamma)
        H_u[k] = h_u

    # Rejection rate over all trials
    print("Test Power:", H_u.sum() / N_f)
    Result = H_u.sum() / N_f
    return Result

In [19]:
def estimate_power_ost(N, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Estimate the rejection rate of ``TST_Ost`` under the alternative.

    ``N`` two-sample problems are drawn from the synthetic target task (first
    sample with ``sigma_mx_1``, second with ``sigma_mx_2``), every kernel in
    ``models`` is evaluated on the stacked sample and the test is run.

    Features are now computed for exactly the kernels in ``models`` (the old
    loop ran over the notebook-level ``num_meta_tasks``) and without building
    an autograd graph. Unused buffers have been removed.

    Reads the notebook-level ``d``, ``mu_mx``, ``mu_mx_s``, ``sigma_mx_1``,
    ``sigma_mx_2`` and ``alpha``.

    Args:
        N: number of simulated trials.
        n_te2: points per cluster and per sample.
        models, sigmas, sigma0s, epsilons: the kernels, one entry per kernel.
        device, dtype: forwarded to ``MatConvert`` and the test.
        gamma: forwarded to the test.

    Returns:
        Fraction of the ``N`` trials in which the test rejected.
    """
    Num_clusters = 2
    kk = 0   # trial offset used in the seeds
    n = 50   # seed offset; fixed, independent of the notebook-level n
    s1_te2 = np.zeros([n_te2 * 2, d])
    s2_te2 = np.zeros([n_te2 * 2, d])
    H_u = np.zeros(N)
    np.random.seed(1102)
    N1_te2 = n_te2*2
    N_f = N*1.0
    for k in range(N):
        # Generate target tasks (seeds fixed per trial and cluster)
        for i in range(Num_clusters):
            np.random.seed(seed=1102 * (k+2) + 2*kk + i + n)
            s1_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        for i in range(Num_clusters):
            np.random.seed(seed=819 * (k + 1) + 2*kk + i + n)
            s2_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx_s[i], sigma_mx_2[i], n_te2)
        S = np.concatenate((s1_te2, s2_te2), axis=0)
        S = MatConvert(S, device, dtype)
        # One feature map per kernel; no gradients are needed for testing.
        with torch.no_grad():
            Feas = [model(S.detach()).detach() for model in models]
        # Run two sample test (deep kernel) on generated data
        h_u, _, _ = TST_Ost(Feas, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype, gamma=gamma)
        H_u[k] = h_u

    # Rejection rate over all trials
    print("Test Power:", H_u.sum() / N_f)
    Result = H_u.sum() / N_f
    return Result

In [20]:
def estimate_power_wald_draft(N, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Estimate the rejection rate of ``TST_Wald`` on ``'cifar'`` samples under the alternative.

    Each of the ``N`` trials calls
    ``generate_samples('cifar', 'alternative', n_te2, ...)`` with the
    notebook-level ``data_all`` / ``data_trans``, stacks the two samples,
    evaluates every kernel in ``models`` and runs the test.

    Features are now computed for exactly the kernels in ``models`` (the old
    loop ran over the notebook-level ``num_meta_tasks``) and without building
    an autograd graph. Unused locals have been removed.

    Also reads the notebook-level ``alpha``.

    Args:
        N: number of simulated trials.
        n_te2: sample size requested from ``generate_samples``; also passed to
            the test as the size of the first sample.
        models, sigmas, sigma0s, epsilons: the kernels, one entry per kernel.
        device, dtype: where/how the stacked sample is placed.
        gamma: forwarded to the test.

    Returns:
        Fraction of the ``N`` trials in which the test rejected.
    """
    H_u = np.zeros(N)
    np.random.seed(1102)
    # N1_te2 = n_te2*2
    N1_te2 = n_te2*1
    N_f = N*1.0
    for k in range(N):
        # Generate target tasks
        dataset = 'cifar'
        hypothesis = 'alternative'
        samplesize = n_te2
        x, y = generate_samples(dataset, hypothesis, samplesize,data_all=data_all,data_trans=data_trans)
        S = torch.cat([x, y], dim=0).to(device, dtype)
        # One feature map per kernel; no gradients are needed for testing.
        with torch.no_grad():
            Feas = [model(S.detach()).detach() for model in models]
        # Run two sample test (deep kernel) on generated data
        h_u, _, _ = TST_Wald(Feas, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype, gamma=gamma)
        H_u[k] = h_u

    # Rejection rate over all trials
    print("Test Power:", H_u.sum() / N_f)
    Result = H_u.sum() / N_f
    return Result

In [21]:
def estimate_power_ost_draft(N, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Estimate the rejection rate of ``TST_Ost`` on ``'cifar'`` samples under the alternative.

    Each of the ``N`` trials calls
    ``generate_samples('cifar', 'alternative', n_te2, ...)`` with the
    notebook-level ``data_all`` / ``data_trans``, stacks the two samples,
    evaluates every kernel in ``models`` and runs the test.

    Features are now computed for exactly the kernels in ``models`` (the old
    loop ran over the notebook-level ``num_meta_tasks``) and without building
    an autograd graph. Unused locals have been removed.

    Also reads the notebook-level ``alpha``.

    Args:
        N: number of simulated trials.
        n_te2: sample size requested from ``generate_samples``; also passed to
            the test as the size of the first sample.
        models, sigmas, sigma0s, epsilons: the kernels, one entry per kernel.
        device, dtype: where/how the stacked sample is placed.
        gamma: forwarded to the test.

    Returns:
        Fraction of the ``N`` trials in which the test rejected.
    """
    H_u = np.zeros(N)
    np.random.seed(1102)
    # N1_te2 = n_te2*2
    N1_te2 = n_te2*1
    N_f = N*1.0
    for k in range(N):
        # Generate target tasks
        dataset = 'cifar'
        hypothesis = 'alternative'
        samplesize = n_te2
        x, y = generate_samples(dataset, hypothesis, samplesize,data_all=data_all,data_trans=data_trans)
        S = torch.cat([x, y], dim=0).to(device, dtype)
        # One feature map per kernel; no gradients are needed for testing.
        with torch.no_grad():
            Feas = [model(S.detach()).detach() for model in models]
        # Run two sample test (deep kernel) on generated data
        h_u, _, _ = TST_Ost(Feas, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype, gamma=gamma)
        H_u[k] = h_u

    # Rejection rate over all trials
    print("Test Power:", H_u.sum() / N_f)
    Result = H_u.sum() / N_f
    return Result

In [22]:
def estimate_type1(weights, N, N_per, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Estimate the type-I error of ``TST_MMD_Multi`` on the synthetic target task.

    Both samples of each of the ``N`` trials are drawn with the same covariance
    ``sigma_mx_1`` (null hypothesis), every kernel in ``models`` is evaluated on
    the stacked sample and the weighted test is run.

    Changes from the previous version: features are computed for exactly the
    kernels in ``models`` (the old loop ran over the notebook-level
    ``num_meta_tasks``) and without building an autograd graph; unused buffers
    are gone; the printed label now says "Type-I Error" instead of the
    misleading "Test Power".

    Reads the notebook-level ``d``, ``mu_mx``, ``mu_mx_s``, ``sigma_mx_1`` and
    ``alpha``.

    Args:
        weights: kernel weights forwarded to ``TST_MMD_Multi``.
        N: number of simulated trials.
        N_per: number of permutations forwarded to ``TST_MMD_Multi``.
        n_te2: points per cluster and per sample.
        models, sigmas, sigma0s, epsilons: the kernels, one entry per kernel.
        device, dtype: forwarded to ``MatConvert`` and the test.
        gamma: forwarded to the test.

    Returns:
        Fraction of the ``N`` null trials in which the test rejected.
    """
    Num_clusters = 2
    kk = 0   # trial offset used in the seeds
    n = 50   # seed offset; fixed, independent of the notebook-level n
    s1_te2 = np.zeros([n_te2 * 2, d])
    s2_te2 = np.zeros([n_te2 * 2, d])
    H_u = np.zeros(N)
    np.random.seed(1102)
    N1_te2 = n_te2*2
    N_f = N*1.0
    for k in range(N):
        # Generate target tasks under the null: same covariance for both samples
        for i in range(Num_clusters):
            np.random.seed(seed=1102 * (k+2) + 2*kk + i + n)
            s1_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        for i in range(Num_clusters):
            np.random.seed(seed=819 * (k + 1) + 2*kk + i + n)
            s2_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx_s[i], sigma_mx_1, n_te2)
        S = np.concatenate((s1_te2, s2_te2), axis=0)
        S = MatConvert(S, device, dtype)
        # One feature map per kernel; no gradients are needed for testing.
        with torch.no_grad():
            Feas = [model(S.detach()).detach() for model in models]
        # Run two sample test (deep kernel) on generated data
        h_u, _, _ = TST_MMD_Multi(weights, Feas, N_per, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype, gamma=gamma)
        H_u[k] = h_u

    # Rejection rate under the null
    print("Type-I Error:", H_u.sum() / N_f)
    Result = H_u.sum() / N_f
    return Result

In [23]:
def estimate_type1_wald(N, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Estimate the type-I error of ``TST_Wald`` on the synthetic target task.

    Both samples of each of the ``N`` trials are drawn with ``mu_mx`` and
    ``sigma_mx_1`` (null hypothesis), every kernel in ``models`` is evaluated
    on the stacked sample and the test is run.

    Changes from the previous version: features are computed for exactly the
    kernels in ``models`` (the old loop ran over the notebook-level
    ``num_meta_tasks``) and without building an autograd graph; unused buffers
    are gone; the printed label now says "Type-I Error" instead of the
    misleading "Test Power".

    Reads the notebook-level ``d``, ``mu_mx``, ``sigma_mx_1`` and ``alpha``.

    Args:
        N: number of simulated trials.
        n_te2: points per cluster and per sample.
        models, sigmas, sigma0s, epsilons: the kernels, one entry per kernel.
        device, dtype: forwarded to ``MatConvert`` and the test.
        gamma: forwarded to the test.

    Returns:
        Fraction of the ``N`` null trials in which the test rejected.
    """
    Num_clusters = 2
    kk = 0   # trial offset used in the seeds
    n = 50   # seed offset; fixed, independent of the notebook-level n
    s1_te2 = np.zeros([n_te2 * 2, d])
    s2_te2 = np.zeros([n_te2 * 2, d])
    H_u = np.zeros(N)
    np.random.seed(1102)
    N1_te2 = n_te2*2
    N_f = N*1.0
    for k in range(N):
        # Generate target tasks under the null: same distribution for both samples
        for i in range(Num_clusters):
            np.random.seed(seed=1102 * (k+2) + 2*kk + i + n)
            s1_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        for i in range(Num_clusters):
            np.random.seed(seed=819 * (k + 1) + 2*kk + i + n)
            s2_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        S = np.concatenate((s1_te2, s2_te2), axis=0)
        S = MatConvert(S, device, dtype)
        # One feature map per kernel; no gradients are needed for testing.
        with torch.no_grad():
            Feas = [model(S.detach()).detach() for model in models]
        # Run two sample test (deep kernel) on generated data
        h_u, _, _ = TST_Wald(Feas, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype, gamma=gamma)
        H_u[k] = h_u

    # Rejection rate under the null
    print("Type-I Error:", H_u.sum() / N_f)
    Result = H_u.sum() / N_f
    return Result

In [24]:
def estimate_type1_wald_draft(N, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Estimate the type-I error of ``TST_Wald`` on ``'cifar'`` samples.

    Each of the ``N`` trials calls
    ``generate_samples('cifar', 'null', n_te2, ...)`` with the notebook-level
    ``data_all`` / ``data_trans``, stacks the two samples, evaluates every
    kernel in ``models`` and runs the test.

    Changes from the previous version: features are computed for exactly the
    kernels in ``models`` (the old loop ran over the notebook-level
    ``num_meta_tasks``) and without building an autograd graph; unused locals
    are gone; the printed label now says "Type-I Error" instead of the
    misleading "Test Power".

    Also reads the notebook-level ``alpha``.

    Args:
        N: number of simulated trials.
        n_te2: sample size requested from ``generate_samples``; also passed to
            the test as the size of the first sample.
        models, sigmas, sigma0s, epsilons: the kernels, one entry per kernel.
        device, dtype: where/how the stacked sample is placed.
        gamma: forwarded to the test.

    Returns:
        Fraction of the ``N`` null trials in which the test rejected.
    """
    H_u = np.zeros(N)
    np.random.seed(1102)
    # N1_te2 = n_te2*2
    N1_te2 = n_te2*1
    N_f = N*1.0
    for k in range(N):
        # Generate target tasks
        dataset = 'cifar'
        hypothesis = 'null'
        samplesize = n_te2
        x, y = generate_samples(dataset, hypothesis, samplesize,data_all=data_all,data_trans=data_trans)
        S = torch.cat([x, y], dim=0).to(device, dtype)
        # One feature map per kernel; no gradients are needed for testing.
        with torch.no_grad():
            Feas = [model(S.detach()).detach() for model in models]
        # Run two sample test (deep kernel) on generated data
        h_u, _, _ = TST_Wald(Feas, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype, gamma=gamma)
        H_u[k] = h_u

    # Rejection rate under the null
    print("Type-I Error:", H_u.sum() / N_f)
    Result = H_u.sum() / N_f
    return Result

In [25]:
def estimate_type1_ost(N, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Estimate the type-I error of ``TST_Ost`` on the synthetic target task.

    Both samples of each of the ``N`` trials are drawn with ``mu_mx`` and
    ``sigma_mx_1`` (null hypothesis), every kernel in ``models`` is evaluated
    on the stacked sample and the test is run.

    Changes from the previous version: features are computed for exactly the
    kernels in ``models`` (the old loop ran over the notebook-level
    ``num_meta_tasks``) and without building an autograd graph; unused buffers
    are gone; the printed label now says "Type-I Error" instead of the
    misleading "Test Power".

    Reads the notebook-level ``d``, ``mu_mx``, ``sigma_mx_1`` and ``alpha``.

    Args:
        N: number of simulated trials.
        n_te2: points per cluster and per sample.
        models, sigmas, sigma0s, epsilons: the kernels, one entry per kernel.
        device, dtype: forwarded to ``MatConvert`` and the test.
        gamma: forwarded to the test.

    Returns:
        Fraction of the ``N`` null trials in which the test rejected.
    """
    Num_clusters = 2
    kk = 0   # trial offset used in the seeds
    n = 50   # seed offset; fixed, independent of the notebook-level n
    s1_te2 = np.zeros([n_te2 * 2, d])
    s2_te2 = np.zeros([n_te2 * 2, d])
    H_u = np.zeros(N)
    np.random.seed(1102)
    N1_te2 = n_te2*2
    N_f = N*1.0
    for k in range(N):
        # Generate target tasks under the null: same distribution for both samples
        for i in range(Num_clusters):
            np.random.seed(seed=1102 * (k+2) + 2*kk + i + n)
            s1_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        for i in range(Num_clusters):
            np.random.seed(seed=819 * (k + 1) + 2*kk + i + n)
            s2_te2[n_te2 * (i):n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        S = np.concatenate((s1_te2, s2_te2), axis=0)
        S = MatConvert(S, device, dtype)
        # One feature map per kernel; no gradients are needed for testing.
        with torch.no_grad():
            Feas = [model(S.detach()).detach() for model in models]
        # Run two sample test (deep kernel) on generated data
        h_u, _, _ = TST_Ost(Feas, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype, gamma=gamma)
        H_u[k] = h_u

    # Rejection rate under the null
    print("Type-I Error:", H_u.sum() / N_f)
    Result = H_u.sum() / N_f
    return Result

In [26]:
def estimate_type1_ost_draft(N, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Estimate the type-I error of ``TST_Ost`` on ``'cifar'`` samples.

    Each of the ``N`` trials calls
    ``generate_samples('cifar', 'null', n_te2, ...)`` with the notebook-level
    ``data_all`` / ``data_trans``, stacks the two samples, evaluates every
    kernel in ``models`` and runs the test.

    Changes from the previous version: features are computed for exactly the
    kernels in ``models`` (the old loop ran over the notebook-level
    ``num_meta_tasks``) and without building an autograd graph; unused locals
    are gone; the printed label now says "Type-I Error" instead of the
    misleading "Test Power".

    Also reads the notebook-level ``alpha``.

    Args:
        N: number of simulated trials.
        n_te2: sample size requested from ``generate_samples``; also passed to
            the test as the size of the first sample.
        models, sigmas, sigma0s, epsilons: the kernels, one entry per kernel.
        device, dtype: where/how the stacked sample is placed.
        gamma: forwarded to the test.

    Returns:
        Fraction of the ``N`` null trials in which the test rejected.
    """
    H_u = np.zeros(N)
    np.random.seed(1102)
    # N1_te2 = n_te2*2
    N1_te2 = n_te2*1
    N_f = N*1.0
    for k in range(N):
        # Generate target tasks
        dataset = 'cifar'
        hypothesis = 'null'
        samplesize = n_te2
        x, y = generate_samples(dataset, hypothesis, samplesize,data_all=data_all,data_trans=data_trans)
        S = torch.cat([x, y], dim=0).to(device, dtype)
        # One feature map per kernel; no gradients are needed for testing.
        with torch.no_grad():
            Feas = [model(S.detach()).detach() for model in models]
        # Run two sample test (deep kernel) on generated data
        h_u, _, _ = TST_Ost(Feas, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype, gamma=gamma)
        H_u[k] = h_u

    # Rejection rate under the null
    print("Type-I Error:", H_u.sum() / N_f)
    Result = H_u.sum() / N_f
    return Result

In [27]:
def get_weights(n_tr, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2, k=0):
    """Draw one training sample pair and return ``get_Analytic_Weights`` on its features.

    The first sample is drawn per cluster from ``(mu_mx[i], sigma_mx_1)`` and the
    second from ``(mu_mx_s[i], sigma_mx_2[i])``; both use the notebook globals
    ``d``, ``mu_mx``, ``mu_mx_s``, ``sigma_mx_1`` and ``sigma_mx_2``.

    Args:
        n_tr: number of points per cluster (must be an integer).
        models: indexable collection of feature extractors, one per kernel.
        sigmas, sigma0s, epsilons: kernel parameters forwarded to
            ``get_Analytic_Weights``.
        device, dtype: passed to ``MatConvert`` and ``get_Analytic_Weights``.
        gamma: forwarded to ``get_Analytic_Weights``.
        k: index folded into the RNG seeds. The original body read ``k`` without
            defining it, i.e. it depended on whatever global ``k`` happened to
            exist. With the default ``k=0`` the seeds equal those of trial 0 in
            ``check_thresh``; pass another value if that overlap is unwanted.

    Returns:
        Whatever ``get_Analytic_Weights`` returns for the generated sample.
    """
    Num_clusters = 2
    N_tr = n_tr * 2
    kk = 0
    n = 50
    num_models = len(sigmas)
    # Both samples live in the same d-dimensional space (the first buffer used
    # to hard-code 2 columns).
    s1_tr = np.zeros([N_tr, d])
    s2_tr = np.zeros([N_tr, d])
    # Generate target task; every draw is preceded by its own explicit seed.
    for i in range(Num_clusters):
        np.random.seed(seed=1102 * (k + 2) + 2 * kk + i + n)
        s1_tr[n_tr * i:n_tr * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_tr)
    for i in range(Num_clusters):
        np.random.seed(seed=819 * (k + 1) + 2 * kk + i + n)
        s2_tr[n_tr * i:n_tr * (i + 1), :] = np.random.multivariate_normal(mu_mx_s[i], sigma_mx_2[i], n_tr)
    S_tr = np.concatenate((s1_tr, s2_tr), axis=0)
    S_tr = MatConvert(S_tr, device, dtype).detach()
    with torch.no_grad():
        Feas = [models[i](S_tr).detach() for i in range(num_models)]
    return get_Analytic_Weights(Feas, N_tr, S_tr, sigmas, sigma0s, epsilons, device, dtype, gamma=gamma)

In [28]:
def get_weights1(n_tr, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2, k=0):
    """Variant of ``get_weights`` used by ``get_type1``.

    Differences from ``get_weights`` visible in the code: every model is switched
    to ``eval()`` first, and the second sample is drawn with the shared covariance
    ``sigma_mx_1`` instead of ``sigma_mx_2[i]``.

    Reads the notebook globals ``d``, ``mu_mx``, ``mu_mx_s`` and ``sigma_mx_1``.

    Args:
        n_tr: number of points per cluster (must be an integer).
        models: iterable and indexable collection of feature extractors.
        sigmas, sigma0s, epsilons: kernel parameters forwarded to
            ``get_Analytic_Weights``.
        device, dtype: passed to ``MatConvert`` and ``get_Analytic_Weights``.
        gamma: forwarded to ``get_Analytic_Weights``.
        k: index folded into the RNG seeds. The original body read ``k`` without
            defining it (hidden notebook state); it is now an explicit argument.

    Returns:
        Whatever ``get_Analytic_Weights`` returns for the generated sample.
    """
    Num_clusters = 2
    N_tr = n_tr * 2
    kk = 0
    n = 50
    num_models = len(sigmas)
    # Same dimensionality for both samples (the first buffer used to hard-code 2).
    s1_tr = np.zeros([N_tr, d])
    s2_tr = np.zeros([N_tr, d])

    for model in models:
        model.eval()

    # Generate target task; every draw is preceded by its own explicit seed.
    for i in range(Num_clusters):
        np.random.seed(seed=1102 * (k + 2) + 2 * kk + i + n)
        s1_tr[n_tr * i:n_tr * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_tr)
    for i in range(Num_clusters):
        np.random.seed(seed=819 * (k + 1) + 2 * kk + i + n)
        # NOTE(review): the other null-hypothesis generators in this notebook
        # (check_thresh_null) draw the second sample from mu_mx[i]; here it is
        # mu_mx_s[i]. Confirm that this difference is intended.
        s2_tr[n_tr * i:n_tr * (i + 1), :] = np.random.multivariate_normal(mu_mx_s[i], sigma_mx_1, n_tr)
    S_tr = np.concatenate((s1_tr, s2_tr), axis=0)
    S_tr = MatConvert(S_tr, device, dtype).detach()
    with torch.no_grad():
        Feas = [models[i](S_tr).detach() for i in range(num_models)]
    return get_Analytic_Weights(Feas, N_tr, S_tr, sigmas, sigma0s, epsilons, device, dtype, gamma=gamma)

In [29]:
def get_power(n, N,ratio, N_per, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Split ``n`` into a training and a test part, fit weights, and estimate power.

    Args:
        n: total per-cluster sample budget.
        N, N_per: forwarded to ``estimate_power``.
        ratio: fraction of ``n`` used for fitting the weights.
        n_te2: unused; kept so the signature matches the sibling helpers.
        models, sigmas, sigma0s, epsilons, device, dtype, gamma: forwarded to
            ``get_weights`` and ``estimate_power``.

    Returns:
        The value returned by ``estimate_power``.
    """
    # n * ratio is a float for fractional ratios, but get_weights uses n_tr as an
    # array size and slice bound, which must be integers.
    n_tr = int(round(n * ratio))
    n_te = n - n_tr
    weights = get_weights(n_tr, models, sigmas, sigma0s, epsilons, device, dtype, gamma=gamma)
    return estimate_power(weights, N, N_per, n_te, models, sigmas, sigma0s, epsilons, device, dtype, gamma=gamma)

In [30]:
def get_type1(n, N,ratio, N_per, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Split ``n`` into a training and a test part, fit weights, and estimate type-I error.

    Args:
        n: total per-cluster sample budget.
        N, N_per: forwarded to ``estimate_type1``.
        ratio: fraction of ``n`` used for fitting the weights.
        n_te2: unused; kept so the signature matches the sibling helpers.
        models, sigmas, sigma0s, epsilons, device, dtype, gamma: forwarded to
            ``get_weights1`` and ``estimate_type1``.

    Returns:
        The value returned by ``estimate_type1``.
    """
    # n * ratio is a float for fractional ratios, but get_weights1 uses n_tr as an
    # array size and slice bound, which must be integers.
    n_tr = int(round(n * ratio))
    n_te = n - n_tr
    weights = get_weights1(n_tr, models, sigmas, sigma0s, epsilons, device, dtype, gamma=gamma)
    return estimate_type1(weights, N, N_per, n_te, models, sigmas, sigma0s, epsilons, device, dtype, gamma=gamma)

In [31]:
def experiment(n, N,ratio, N_per, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Run the type-I and power estimates with identical settings.

    All arguments are forwarded unchanged to ``get_type1`` and ``get_power``.

    Returns:
        tuple: ``(power, type1)``.
    """
    # Forward the caller's gamma; it used to be pinned to 2 in both calls, so the
    # argument was silently ignored.
    type1 = get_type1(n, N, ratio, N_per, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=gamma)
    power = get_power(n, N, ratio, N_per, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=gamma)
    return power, type1

In [32]:
def check_thresh(N, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Collect ``N`` values of the statistic ``sqrt(weights . mmd)``.

    In every trial the first sample is drawn per cluster from
    ``(mu_mx[i], sigma_mx_1)`` and the second from ``(mu_mx_s[i], sigma_mx_2[i])``.
    The stacked sample is passed through each model, ``MMDg`` produces ``mmd`` and
    ``V``, and ``Analytic_Weights(mmd, V)`` supplies the weights.

    Reads the notebook globals ``d``, ``mu_mx``, ``mu_mx_s``, ``sigma_mx_1``,
    ``sigma_mx_2`` and ``alpha``. All models are switched to ``eval()``.

    NOTE(review): ``Analytic_Weights`` is treated as returning a single tensor
    here, while ``check_thresh_null`` unpacks three values from it. Confirm which
    signature is current.

    Args:
        N: number of trials.
        n_te2: number of points per cluster in each sample.
        models: iterable and indexable collection of feature extractors.
        sigmas, sigma0s, epsilons: kernel parameters forwarded to ``MMDg``; the
            number of models used is ``len(sigmas)``.
        device, dtype: passed to ``MatConvert`` and ``MMDg``.
        gamma: forwarded to ``MMDg``.

    Returns:
        torch.Tensor of length ``N`` and dtype ``dtype`` holding one statistic per trial.
    """
    Num_clusters = 2
    kk = 0
    n = 50
    num_meta_tasks = len(sigmas)
    N1_te2 = n_te2 * 2
    s1_te2 = np.zeros([N1_te2, d])
    s2_te2 = np.zeros([N1_te2, d])
    t = torch.zeros(N, dtype=dtype)
    np.random.seed(1102)
    for model in models:
        model.eval()

    for k in range(N):
        # Generate target tasks
        for i in range(Num_clusters):
            np.random.seed(seed=1102 * (k + 2) + 2 * kk + i + n)
            s1_te2[n_te2 * i:n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        for i in range(Num_clusters):
            np.random.seed(seed=819 * (k + 1) + 2 * kk + i + n)
            s2_te2[n_te2 * i:n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx_s[i], sigma_mx_2[i], n_te2)
        S = np.concatenate((s1_te2, s2_te2), axis=0)
        S = MatConvert(S, device, dtype)
        # Features are detached right away, so skip building autograd graphs.
        with torch.no_grad():
            Feas = [models[i](S).detach() for i in range(num_meta_tasks)]
        # Run two sample test (deep kernel) on generated data
        mmd, V, _ = MMDg(Feas, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype, gamma=gamma)
        weights = Analytic_Weights(mmd, V).to(mmd)
        # torch.sqrt keeps the value on its device; np.sqrt would need a NumPy
        # conversion, which fails for CUDA tensors.
        t[k] = torch.sqrt(weights.reshape(-1).dot(mmd))
        torch.cuda.empty_cache()
    return t

In [33]:
def check_thresh_null(N, n_te2, models, sigmas, sigma0s, epsilons, device, dtype, gamma=2):
    """Collect ``N`` values of ``sqrt(weights . mmd)`` with both samples drawn identically.

    Both samples of every trial are drawn per cluster from ``(mu_mx[i], sigma_mx_1)``
    (with different seeds). The stacked sample is passed through each model,
    ``MMDg`` produces ``mmd`` and ``V``, and ``Analytic_Weights(mmd, V)`` supplies
    the weights together with the ``mmd`` vector that is actually used.

    Reads the notebook globals ``d``, ``mu_mx``, ``sigma_mx_1`` and ``alpha``.
    All models are switched to ``eval()``.

    NOTE(review): this function unpacks three values from ``Analytic_Weights``,
    while ``check_thresh`` treats its result as a single tensor. Confirm which
    signature is current.

    Args:
        N: number of trials.
        n_te2: number of points per cluster in each sample.
        models: iterable and indexable collection of feature extractors.
        sigmas, sigma0s, epsilons: kernel parameters forwarded to ``MMDg``; the
            number of models used is ``len(sigmas)``.
        device, dtype: passed to ``MatConvert`` and ``MMDg``.
        gamma: forwarded to ``MMDg``.

    Returns:
        tuple: ``(t, num_stats)`` where ``t`` is a tensor of length ``N`` with one
        statistic per trial and ``num_stats`` is ``mmd.shape[0]`` of the last
        trial (``0`` when ``N == 0``; previously that case raised because ``mmd``
        was never bound).
    """
    Num_clusters = 2
    kk = 0
    n = 50
    num_meta_tasks = len(sigmas)
    N1_te2 = n_te2 * 2
    s1_te2 = np.zeros([N1_te2, d])
    s2_te2 = np.zeros([N1_te2, d])
    t = torch.zeros(N, dtype=dtype)
    num_stats = 0
    np.random.seed(1102)
    for model in models:
        model.eval()

    for k in range(N):
        # Generate target tasks
        for i in range(Num_clusters):
            np.random.seed(seed=1102 * (k + 2) + 2 * kk + i + n)
            s1_te2[n_te2 * i:n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        for i in range(Num_clusters):
            np.random.seed(seed=819 * (k + 1) + 2 * kk + i + n)
            s2_te2[n_te2 * i:n_te2 * (i + 1), :] = np.random.multivariate_normal(mu_mx[i], sigma_mx_1, n_te2)
        S = np.concatenate((s1_te2, s2_te2), axis=0)
        S = MatConvert(S, device, dtype)
        # Features are detached right away, so skip building autograd graphs.
        with torch.no_grad():
            Feas = [models[i](S).detach() for i in range(num_meta_tasks)]
        # Run two sample test (deep kernel) on generated data
        mmd, V, _ = MMDg(Feas, N1_te2, S, sigmas, sigma0s, epsilons, alpha, device, dtype, gamma=gamma)
        # Analytic_Weights hands back the mmd vector to use with the weights.
        weights, mmd, _ = Analytic_Weights(mmd, V)
        t[k] = torch.sqrt(weights.to(mmd).reshape(-1).dot(mmd))
        num_stats = mmd.shape[0]
        torch.cuda.empty_cache()
    return t, num_stats

In [34]:
# Temp = get_weights(100, models, sigmas, sigma0s, epsilons, device, dtype,gamma= gamma)

In [35]:
import gc

# List every tensor (or object whose ``.data`` is a tensor) that the garbage
# collector currently tracks, then print how many were found.
count = 0
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            count += 1
            print(type(obj), obj.size())
    except Exception:
        # Best effort: some tracked objects raise on attribute access. Catch
        # Exception only, so that KeyboardInterrupt / SystemExit still stop the cell.
        pass
print(count)

0




In [36]:
#### sigmas[0].shape

In [37]:
# Temp1 = get_weights1(100, models, sigmas, sigma0s, epsilons, device, dtype,gamma=gamma)

In [38]:
# estimate_type1(Temp1, N, N_per, n_te2, models, sigmas, sigma0s, epsilons, device, dtype,gamma=gamma)

In [39]:
# Temp = get_weights(100, models, sigmas, sigma0s, epsilons, device, dtype,gamma=gamma)

In [40]:
from scipy.stats import chisquare
import scipy.stats as stats

In [41]:
# t_obs = check_thresh(1000, n_te2, models, sigmas, sigma0s, epsilons, device, dtype,gamma=gamma)

In [42]:
# df = 100
# alpha = 0.05
# x = np.linspace(stats.chi.ppf(0.001,df),
#                 stats.chi.ppf(0.999,df), 1000)
# rv = stats.chi(df)
# plt.plot(x, rv.pdf(x), color='r')
# plt.vlines(np.quantile(np.array(t_obs),1-alpha),0,1,color = 'b')
# plt.vlines(rv.ppf(1-alpha),0,1,color = 'r')
# plt.hist(np.array(t_obs),bins = 100, density=True)
# plt.savefig('temp.png')

In [43]:
# t_obs_null = check_thresh_null(1000, 2000, models[:7], sigmas[:7], sigma0s[:7], epsilons[:7], device, dtype,gamma=gamma)

In [44]:
# df = 6
# alpha = 0.05
# x = np.linspace(stats.chi.ppf(0.001,df),
#                 stats.chi.ppf(0.999,df), 1000)
# rv = stats.chi(df)
# plt.plot(x, rv.pdf(x), color='r')
# plt.vlines(np.quantile(np.array(t_obs_null),1-alpha),0,1,color = 'b')
# plt.vlines(rv.ppf(1-alpha),0,1,color = 'r')
# plt.hist(np.array(t_obs_null),bins = 30, density=True)
# plt.savefig('temp_null.png')

In [45]:
# Notebook-level gamma, passed as ``gamma=gamma`` by some of the estimate_* calls.
# NOTE(review): several later cells hard-code their own gamma (1.999, 1.9999,
# 1.99999) instead of using this value.
gamma=1.9999

In [46]:
# power_wald = estimate_power_wald(1000, 250, models, sigmas, sigma0s, epsilons, device, dtype,gamma=1.9999)

In [47]:
# power_ost = estimate_power_ost(1000, 400, models, sigmas, sigma0s, epsilons, device, dtype,gamma=1.9999)

In [48]:
# power_ost

In [49]:
# type1_wald = estimate_type1_wald(1000, 250, models, sigmas, sigma0s, epsilons, device, dtype,gamma=gamma)

In [50]:
# type1_ost = estimate_type1_ost(1000, 250, models, sigmas, sigma0s, epsilons, device, dtype,gamma=gamma)

In [51]:
# Side length the images are resized to; also fixes the trailing two
# dimensions of data_trans below.
imgsz = 64
# CIFAR-10 test split read from ./data/ (download disabled), resized to imgsz,
# converted to tensors and normalised with mean 0.5 / std 0.5 per channel.
dataset_test = datasets.CIFAR10(root='./data/', download=False, train=False,
                                transform=transforms.Compose([
                                    transforms.Resize(imgsz),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                ]))

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=10000,
                                              shuffle=True, num_workers=1)
# Obtain CIFAR10 images: data_all keeps the last batch the loader yields.
# With batch_size=10000 this is presumably the whole test split in a single
# batch - confirm. The loop also leaves i, imgs and Labels behind as globals.
for i, (imgs, Labels) in enumerate(dataloader_test):
    data_all = imgs

# Obtain CIFAR10.1 images
#data_new = np.load('./cifar10.1_v4_data.npy')
data_new = np.load('./data/cifar10.1_v4_data.npy')
# Move the last axis to position 1 (channels-last -> channels-first).
data_T = np.transpose(data_new, [0, 3, 1, 2])
# Random permutation of all indices. No seed is set in this cell, so the order
# depends on the state of the global NumPy RNG at this point.
ind_M = np.random.choice(len(data_T), len(data_T), replace=False)
data_T = data_T[ind_M]
# Same resize / tensor / normalise pipeline as for the CIFAR-10 test split.
TT = transforms.Compose([transforms.Resize(imgsz), transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trans = transforms.ToPILImage()
data_trans = torch.zeros([len(data_T), 3, imgsz, imgsz])
data_T_tensor = torch.from_numpy(data_T)
# Convert each image to PIL so the torchvision pipeline TT can be applied to it.
for i in range(len(data_T)):
    d0 = trans(data_T_tensor[i])
    data_trans[i] = TT(d0)

In [52]:
# Layer specification handed to load_kernels_cifar. For 'conv2d' the first four
# numbers match the weight shapes listed in the tensor dump further down
# (e.g. [16, 3, 3, 3]); the trailing (2, 1) are presumably stride and padding
# - confirm against the Learner implementation.
config = [
    # block 1: 3 -> 16 channels (no batch norm)
    ('conv2d', [16, 3, 3, 3, 2, 1]), ('leakyrelu', [True]),
    # block 2: 16 -> 32 channels
    ('conv2d', [32, 16, 3, 3, 2, 1]), ('leakyrelu', [True]), ('bn', [32]),
    # block 3: 32 -> 64 channels
    ('conv2d', [64, 32, 3, 3, 2, 1]), ('leakyrelu', [True]), ('bn', [64]),
    # block 4: 64 -> 128 channels
    ('conv2d', [128, 64, 3, 3, 2, 1]), ('leakyrelu', [True]), ('bn', [128]),
    # head
    ('flatten', []),
    ('linear', [300, 2048]),
]
# Number of meta tasks whose kernels are loaded and evaluated.
num_meta_tasks = 45

In [53]:
# Load the CIFAR models and their kernel parameters for num_meta_tasks tasks.
# load_kernels_cifar is defined outside this part of the notebook; presumably it
# returns one entry per task in each of the four collections - confirm.
models_c, sigmas_c, sigma0s_c, epsilons_c = load_kernels_cifar(num_meta_tasks,config)

In [54]:
import gc

# List every tensor (or object whose ``.data`` is a tensor) that the garbage
# collector currently tracks, then print how many were found.
count = 0
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            count += 1
            print(type(obj), obj.size())
    except Exception:
        # Best effort: some tracked objects raise on attribute access. Catch
        # Exception only, so that KeyboardInterrupt / SystemExit still stop the cell.
        pass
print(count)

<class 'torch.nn.parameter.Parameter'> torch.Size([16, 3, 3, 3])
<class 'torch.nn.parameter.Parameter'> torch.Size([16])
<class 'torch.nn.parameter.Parameter'> torch.Size([32, 16, 3, 3])
<class 'torch.nn.parameter.Parameter'> torch.Size([32])
<class 'torch.nn.parameter.Parameter'> torch.Size([32])
<class 'torch.nn.parameter.Parameter'> torch.Size([32])
<class 'torch.nn.parameter.Parameter'> torch.Size([32])
<class 'torch.nn.parameter.Parameter'> torch.Size([32])
<class 'torch.nn.parameter.Parameter'> torch.Size([64, 32, 3, 3])
<class 'torch.nn.parameter.Parameter'> torch.Size([64])
<class 'torch.nn.parameter.Parameter'> torch.Size([64])
<class 'torch.nn.parameter.Parameter'> torch.Size([64])
<class 'torch.nn.parameter.Parameter'> torch.Size([64])
<class 'torch.nn.parameter.Parameter'> torch.Size([64])
<class 'torch.nn.parameter.Parameter'> torch.Size([128, 64, 3, 3])
<class 'torch.nn.parameter.Parameter'> torch.Size([128])
<class 'torch.nn.parameter.Parameter'> torch.Size([128])
<class

In [55]:
# Shape of the raw CIFAR-10.1 array as loaded from disk (before the transpose
# into data_T); the recorded output is (2021, 32, 32, 3).
data_new.shape

(2021, 32, 32, 3)

In [56]:
# power_ost_cifar = estimate_power_ost_draft(1000, 250, models_c, sigmas_c, sigma0s_c, epsilons_c, device, dtype,gamma=1.9999)

In [None]:
# CIFAR power run via estimate_power_ost_draft (defined outside this part of the
# notebook). The leading 1000 and 300 are presumably the trial count and the
# sample size, by analogy with estimate_type1_ost_draft(N, n_te2, ...) - confirm.
# NOTE(review): gamma=1.999 here, while the other runs use 1.9999 or 1.99999.
power_ost_cifar = estimate_power_ost_draft(1000, 300, models_c, sigmas_c, sigma0s_c, epsilons_c, device, dtype,gamma=1.999)

In [None]:
# Display the value currently bound to power_ost_cifar.
power_ost_cifar

In [None]:
# power_ost_cifar = estimate_power_wald_draft(1000, 300, models_c, sigmas_c, sigma0s_c, epsilons_c, device, dtype,gamma=1.999)

In [None]:
# CIFAR power run with second argument 400 (presumably the sample size - confirm).
# NOTE(review): rebinding power_ost_cifar discards the result of any earlier run.
power_ost_cifar = estimate_power_ost_draft(1000, 400, models_c, sigmas_c, sigma0s_c, epsilons_c, device, dtype,gamma=1.9999)

In [None]:
# CIFAR power run with second argument 600 (presumably the sample size - confirm).
# NOTE(review): gamma=1.99999 here, while the neighbouring runs use 1.9999;
# rebinding power_ost_cifar also discards the result of any earlier run.
power_ost_cifar = estimate_power_ost_draft(1000, 600, models_c, sigmas_c, sigma0s_c, epsilons_c, device, dtype,gamma=1.99999)

In [None]:
# CIFAR power run with second argument 700 (presumably the sample size - confirm).
# NOTE(review): rebinding power_ost_cifar discards the result of any earlier run.
power_ost_cifar = estimate_power_ost_draft(1000, 700, models_c, sigmas_c, sigma0s_c, epsilons_c, device, dtype,gamma=1.9999)

In [None]:
# CIFAR power run with second argument 1000 (presumably the sample size - confirm).
# NOTE(review): rebinding power_ost_cifar discards the result of any earlier run.
power_ost_cifar = estimate_power_ost_draft(1000, 1000, models_c, sigmas_c, sigma0s_c, epsilons_c, device, dtype,gamma=1.9999)

In [None]:
# CIFAR power run with second argument 1100 (presumably the sample size - confirm).
# NOTE(review): rebinding power_ost_cifar discards the result of any earlier run.
power_ost_cifar = estimate_power_ost_draft(1000, 1100, models_c, sigmas_c, sigma0s_c, epsilons_c, device, dtype,gamma=1.9999)

In [None]:
# Null-hypothesis run on CIFAR: N=1000 trials with sample size n_te2=250.
type1_ost_cifar = estimate_type1_ost_draft(1000, 250, models_c, sigmas_c, sigma0s_c, epsilons_c, device, dtype,gamma=1.9999)

In [None]:
class MMDLoss:
    """Draft container for the settings of a weighted multi-kernel MMD loss.

    ``__init__`` stores its arguments and initialises one unit weight per model.

    NOTE(review): ``masking_loss`` and the two closures built in ``__init__`` use
    attributes that this class never assigns (``fnet_sum``, ``fnet_pred``,
    ``net_params``, ``masked_params``, ``masks``, ``gamma_1``, ``gamma_2``) and the
    names ``vmap``, ``empirical_ntk_batch_cross`` and ``pt``, none of which appear
    in the notebook's import cell. They look like code carried over from a
    different project and will fail when called unless those are provided.
    """

    def __init__(self, N, N_per, n_te, models, sigmas, sigma0s, epsilons):
        """Store the experiment settings.

        Args:
            N, N_per, n_te: stored unchanged on the instance.
            models: sized collection of models; its length sets the number of weights.
            sigmas, sigma0s, epsilons: stored unchanged on the instance.
        """
        # One weight per model. The original read ``model.shape[0]`` and ``net``,
        # neither of which is defined here (typos for ``models`` and ``n_te``).
        num_kernels = len(models)
        self.weights = np.ones(num_kernels)
        self.N = N
        self.N_per = N_per
        self.n_te = n_te
        self.models = models
        self.sigmas = sigmas
        self.sigma0s = sigma0s
        self.epsilons = epsilons

        def fnet_sum_single(params, x):
            # Add a leading axis, call self.fnet_sum, then strip that axis again.
            x = x.unsqueeze(0)
            f_sum = self.fnet_sum(params, x)
            f_sum = f_sum.squeeze(0)
            return f_sum

        def fnet_pred_single(params, x):
            # Add a leading axis, call self.fnet_pred, then strip that axis again.
            x = x.unsqueeze(0)
            f_pred = self.fnet_pred(params, x)
            f_pred = f_pred.squeeze(0)
            return f_pred

        self.fnet_sum_single = fnet_sum_single
        self.fnet_pred_single = fnet_pred_single

    def masking_loss(self, x_train):
        """Return the sum of a prediction term, an NTK term and a mask penalty.

        NOTE(review): depends on attributes and functions that are not defined in
        this class (see the class docstring); left unchanged.
        """
        full_pred = vmap(self.fnet_pred_single, (None, 0))(self.net_params, x_train)
        masked_pred = vmap(self.fnet_pred_single, (None, 0))(self.masked_params, x_train)

        full_ntk = empirical_ntk_batch_cross(self.fnet_sum_single, self.net_params, x_train)
        masked_ntk = empirical_ntk_batch_cross(self.fnet_sum_single, self.masked_params, x_train)

        batch_size = x_train.shape[0]
        pred_term = (1. / batch_size) * pt.sum((full_pred - masked_pred) ** 2)
        ntk_term = (self.gamma_1 / batch_size) ** 2 * pt.sum((full_ntk - masked_ntk) ** 2)
        mask_term = self.gamma_2 * pt.sum(pt.stack([pt.sum(pt.abs(m)) for m in self.masks]))

        loss = pred_term + ntk_term + mask_term
        return loss

In [None]:
# Toy inputs for trying out the Cholesky-based solve: V becomes the symmetric
# Gram matrix A @ A.T of the 3x3 matrix A below, tau is the right-hand side.
V = torch.Tensor([
    [0., -3., -2.],
    [1., -4., -2.],
    [-3., 4., 1.],
])
V = torch.matmul(V, V.T)
tau = torch.Tensor([3., 2., 1.])
# Direct solve, for reference:
#   torch.cholesky_solve(tau.reshape(-1, 1), torch.linalg.cholesky(V))

In [None]:
def Analytic_Weights(mmd, V, epsilon=1e-5):
    """Solve ``(V + epsilon * I) beta = mmd`` with a Cholesky factorisation.

    Both inputs are converted to the notebook-global ``dtype`` and ``device``
    before solving.

    NOTE(review): this definition shadows the ``Analytic_Weights`` imported from
    ``utils`` in the first cell. It returns a single tensor, whereas
    ``check_thresh_null`` unpacks three values from ``Analytic_Weights``.

    Args:
        mmd: tensor whose first dimension gives the system size; it is reshaped to
            a column vector for the solve.
        V: square matrix of matching size.
        epsilon: ridge term added to the diagonal of ``V`` before factorising
            (previously hard-coded to ``1e-5``).

    Returns:
        torch.Tensor of shape ``(mmd.shape[0], 1)`` in the global ``dtype`` / ``device``.
    """
    size = mmd.shape[0]
    I = torch.eye(size, dtype=dtype, device=device)
    V = V.to(I)
    mmd = mmd.to(I)
    L = torch.linalg.cholesky(V + epsilon * I)
    # The solution already has I's dtype and device, so no further conversion is needed.
    return torch.cholesky_solve(mmd.reshape(-1, 1), L)

In [None]:
# Solve for the weights of the toy system. This assumes the notebook-level
# Analytic_Weights (single return value) is the one currently bound.
res = Analytic_Weights(tau, V)

In [None]:
# Sanity check: multiplying V back onto the solution should approximately
# reproduce tau (as a column), up to the small ridge term used in the solve.
V.to(res)@res